From a6baf4a4b54d573f76141c21d40d508c0fc447d1 Mon Sep 17 00:00:00 2001
From: lllyasviel <19834515+lllyasviel@users.noreply.github.com>
Date: Wed, 7 Aug 2024 16:51:24 -0700
Subject: [PATCH] revise kernel and add unused files

---
 backend/args.py                  |   2 +
 backend/diffusion_engine/base.py |   1 -
 backend/diffusion_engine/flux.py | 103 ++++++++++
 backend/loader.py                |  71 ++++---
 backend/nn/flux.py               | 326 +++++++++++++++++++++++++++++++
 backend/nn/t5.py                 | 212 ++++++++++++++++++++
 backend/operations.py            |   6 +-
 backend/sampling/condition.py    |   3 +
 modules/processing.py            |   2 +-
 modules/prompt_parser.py         |  24 +--
 modules/sd_vae_approx.py         |   2 +-
 11 files changed, 700 insertions(+), 52 deletions(-)
 create mode 100644 backend/diffusion_engine/flux.py
 create mode 100644 backend/nn/flux.py
 create mode 100644 backend/nn/t5.py

diff --git a/backend/args.py b/backend/args.py
index 0302dfc1..3254b7f8 100644
--- a/backend/args.py
+++ b/backend/args.py
@@ -56,6 +56,8 @@
 parser.add_argument("--cuda-malloc", action="store_true")
 parser.add_argument("--cuda-stream", action="store_true")
 parser.add_argument("--pin-shared-memory", action="store_true")
+parser.add_argument("--i-am-lllyasviel", action="store_true")
+
 args = parser.parse_known_args()[0]
 
 # Some dynamic args that may be changed by webui rather than cmd flags.
diff --git a/backend/diffusion_engine/base.py b/backend/diffusion_engine/base.py
index 76048020..5ad5ac2f 100644
--- a/backend/diffusion_engine/base.py
+++ b/backend/diffusion_engine/base.py
@@ -34,7 +34,6 @@ class ForgeDiffusionEngine:
         self.first_stage_model = None  # set this so that you can change VAE in UI
 
         # WebUI Dirty Legacy
-        self.latent_channels = 4
         self.is_sd1 = False
         self.is_sd2 = False
         self.is_sdxl = False
diff --git a/backend/diffusion_engine/flux.py b/backend/diffusion_engine/flux.py
new file mode 100644
index 00000000..886ee23c
--- /dev/null
+++ b/backend/diffusion_engine/flux.py
@@ -0,0 +1,103 @@
+import torch
+
+from huggingface_guess import model_list
+from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
+from backend.patcher.clip import CLIP
+from backend.patcher.vae import VAE
+from backend.patcher.unet import UnetPatcher
+from backend.text_processing.classic_engine import ClassicTextProcessingEngine
+from backend.text_processing.t5_engine import T5TextProcessingEngine
+from backend.args import dynamic_args, args
+from backend.modules.k_prediction import PredictionFlux
+from backend import memory_management
+
+
+class Flux(ForgeDiffusionEngine):
+    matched_guesses = [model_list.Flux]
+
+    def __init__(self, estimated_config, huggingface_components):
+        if not args.i_am_lllyasviel:
+            raise NotImplementedError('Flux is not implemented yet!')
+
+        super().__init__(estimated_config, huggingface_components)
+        self.is_inpaint = False
+
+        clip = CLIP(
+            model_dict={
+                'clip_l': huggingface_components['text_encoder'],
+                't5xxl': huggingface_components['text_encoder_2']
+            },
+            tokenizer_dict={
+                'clip_l': huggingface_components['tokenizer'],
+                't5xxl': huggingface_components['tokenizer_2']
+            }
+        )
+
+        vae = VAE(model=huggingface_components['vae'])
+
+        unet = UnetPatcher.from_model(
+            model=huggingface_components['transformer'],
+            diffusers_scheduler=None,
+            k_predictor=PredictionFlux(sigma_data=1.0, prediction_type='const', shift=1.0, timesteps=10000)
+        )
+
+        self.text_processing_engine_l = ClassicTextProcessingEngine(
+            text_encoder=clip.cond_stage_model.clip_l,
+            tokenizer=clip.tokenizer.clip_l,
+            embedding_dir=dynamic_args['embedding_dir'],
+            embedding_key='clip_l',
+            embedding_expected_shape=768,
+            emphasis_name=dynamic_args['emphasis_name'],
+            text_projection=False,
+            minimal_clip_skip=1,
+            clip_skip=1,
+            return_pooled=True,
+            final_layer_norm=True,
+        )
+
+        self.text_processing_engine_t5 = T5TextProcessingEngine(
+            text_encoder=clip.cond_stage_model.t5xxl,
+            tokenizer=clip.tokenizer.t5xxl,
+            emphasis_name=dynamic_args['emphasis_name'],
+        )
+
+        self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
+        self.forge_objects_original = self.forge_objects.shallow_copy()
+        self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
+
+        # WebUI Legacy
+        self.first_stage_model = vae.first_stage_model
+
+    def set_clip_skip(self, clip_skip):
+        self.text_processing_engine_l.clip_skip = clip_skip
+
+    @torch.inference_mode()
+    def get_learned_conditioning(self, prompt: list[str]):
+        memory_management.load_model_gpu(self.forge_objects.clip.patcher)
+        cond_l, pooled_l = self.text_processing_engine_l(prompt)
+        cond_t5 = self.text_processing_engine_t5(prompt)
+
+        cond = dict(
+            crossattn=cond_t5,
+            vector=pooled_l,
+            guidance=torch.FloatTensor([3.5] * len(prompt))
+        )
+
+        return cond
+
+    @torch.inference_mode()
+    def get_prompt_lengths_on_ui(self, prompt):
+        _, token_count = self.text_processing_engine_t5.process_texts([prompt])
+        return token_count, max(255, token_count)
+
+    @torch.inference_mode()
+    def encode_first_stage(self, x):
+        sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
+        sample = self.forge_objects.vae.first_stage_model.process_in(sample)
+        return sample.to(x)
+
+    @torch.inference_mode()
+    def decode_first_stage(self, x):
+        sample = self.forge_objects.vae.first_stage_model.process_out(x)
+        sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
+        return sample.to(x)
diff --git a/backend/loader.py b/backend/loader.py
index cdb224da..d5ddc10c 100644
--- a/backend/loader.py
+++ b/backend/loader.py
@@ -19,11 +19,10 @@ from backend.nn.unet import IntegratedUNet2DConditionModel
 from backend.diffusion_engine.sd15 import StableDiffusion
 from backend.diffusion_engine.sd20 import StableDiffusion2
 from backend.diffusion_engine.sdxl import StableDiffusionXL
+from backend.diffusion_engine.flux import Flux
 
 
-possible_models = [
-    StableDiffusion, StableDiffusion2, StableDiffusionXL,
-]
+possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXL, Flux]
 
 
 logging.getLogger("diffusers").setLevel(logging.ERROR)
@@ -65,25 +64,25 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         ], log_name=cls_name)
         return model
 
-    # if cls_name == 'T5EncoderModel':
-    #     from backend.nn.t5 import IntegratedT5
-    #     config = read_arbitrary_config(config_path)
-    #
-    #     dtype = memory_management.text_encoder_dtype()
-    #     sd_dtype = state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight'].dtype
-    #     need_cast = False
-    #
-    #     if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
-    #         dtype = sd_dtype
-    #         need_cast = True
-    #
-    #     with modeling_utils.no_init_weights():
-    #         with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=need_cast):
-    #             model = IntegratedT5(config)
-    #
-    #     load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight'])
-    #
-    #     return model
+    if cls_name == 'T5EncoderModel':
+        from backend.nn.t5 import IntegratedT5
+        config = read_arbitrary_config(config_path)
+
+        dtype = memory_management.text_encoder_dtype()
+        sd_dtype = state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight'].dtype
+        need_cast = False
+
+        if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+            dtype = sd_dtype
+            need_cast = True
+
+        with modeling_utils.no_init_weights():
+            with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=need_cast):
+                model = IntegratedT5(config)
+
+        load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight'])
+
+        return model
     if cls_name == 'UNet2DConditionModel':
         unet_config = guess.unet_config.copy()
         state_dict_size = memory_management.state_dict_size(state_dict)
@@ -97,20 +96,20 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         load_state_dict(model, state_dict)
         return model
 
-    # if cls_name == 'FluxTransformer2DModel':
-    #     from backend.nn.flux import IntegratedFluxTransformer2DModel
-    #     unet_config = guess.unet_config.copy()
-    #     state_dict_size = memory_management.state_dict_size(state_dict)
-    #     ini_dtype = memory_management.unet_dtype(model_params=state_dict_size)
-    #     ini_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=ini_dtype)
-    #     to_args = dict(device=ini_device, dtype=ini_dtype)
-    #
-    #     with using_forge_operations(**to_args):
-    #         model = IntegratedFluxTransformer2DModel(**unet_config).to(**to_args)
-    #         model.config = unet_config
-    #
-    #     load_state_dict(model, state_dict)
-    #     return model
+    if cls_name == 'FluxTransformer2DModel':
+        from backend.nn.flux import IntegratedFluxTransformer2DModel
+        unet_config = guess.unet_config.copy()
+        state_dict_size = memory_management.state_dict_size(state_dict)
+        ini_dtype = memory_management.unet_dtype(model_params=state_dict_size)
+        ini_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=ini_dtype)
+        to_args = dict(device=ini_device, dtype=ini_dtype)
+
+        with using_forge_operations(**to_args):
+            model = IntegratedFluxTransformer2DModel(**unet_config).to(**to_args)
+            model.config = unet_config
+
+        load_state_dict(model, state_dict)
+        return model
 
     print(f'Skipped: {component_name} = {lib_name}.{cls_name}')
     return None
diff --git a/backend/nn/flux.py b/backend/nn/flux.py
new file mode 100644
index 00000000..54800f96
--- /dev/null
+++ b/backend/nn/flux.py
@@ -0,0 +1,326 @@
+# Single File Implementation of Flux, by Forge
+# See also https://github.com/black-forest-labs/flux
+
+
+import math
+import torch
+
+from einops import rearrange, repeat
+from torch import nn
+from dataclasses import dataclass
+from backend.attention import attention_function
+
+
+def attention(q, k, v, pe):
+    q, k = apply_rope(q, k, pe)
+    x = attention_function(q, k, v, q.shape[1], skip_reshape=True)
+    return x
+
+
+def rope(pos, dim, theta):
+    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+    omega = 1.0 / (theta ** scale)
+    out = torch.einsum("...n,d->...nd", pos, omega)
+    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
+    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
+    return out.float()
+
+
+def apply_rope(xq, xk, freqs_cis):
+    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+
+
+def timestep_embedding(t, dim, max_period=10000, time_factor=1000.0):
+    t = time_factor * t
+    half = dim // 2
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
+    args = t[:, None].float() * freqs[None]
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+    if dim % 2:
+        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+    if torch.is_floating_point(t):
+        embedding = embedding.to(t)
+    return embedding
+
+
+@dataclass
+class ModulationOut:
+    shift: torch.Tensor
+    scale: torch.Tensor
+    gate: torch.Tensor
+
+
+@dataclass
+class FluxParams:
+    in_channels: int
+    vec_in_dim: int
+    context_in_dim: int
+    hidden_size: int
+    mlp_ratio: float
+    num_heads: int
+    depth: int
+    depth_single_blocks: int
+    axes_dim: list[int]
+    theta: int
+    qkv_bias: bool
+    guidance_embed: bool
+
+
+class EmbedND(nn.Module):
+    def __init__(self, dim, theta, axes_dim):
+        super().__init__()
+        self.dim = dim
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def forward(self, ids):
+        n_axes = ids.shape[-1]
+        emb = torch.cat(
+            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+            dim=-3,
+        )
+        return emb.unsqueeze(1)
+
+
+class MLPEmbedder(nn.Module):
+    def __init__(self, in_dim, hidden_dim):
+        super().__init__()
+        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
+        self.silu = nn.SiLU()
+        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
+
+    def forward(self, x):
+        return self.out_layer(self.silu(self.in_layer(x)))
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.scale = nn.Parameter(torch.ones(dim))
+
+    def forward(self, x):
+        to_args = dict(device=x.device, dtype=x.dtype)
+        x = x.float()
+        rrms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + 1e-6)
+        return (x * rrms.to(x) * self.scale.to(x)).to(**to_args)
+
+
+class QKNorm(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.query_norm = RMSNorm(dim)
+        self.key_norm = RMSNorm(dim)
+
+    def forward(self, q, k, v):
+        q = self.query_norm(q)
+        k = self.key_norm(k)
+        return q.to(v), k.to(v)
+
+
+class SelfAttention(nn.Module):
+    def __init__(self, dim, num_heads=8, qkv_bias=False):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.norm = QKNorm(head_dim)
+        self.proj = nn.Linear(dim, dim)
+
+    def forward(self, x, pe):
+        qkv = self.qkv(x)
+        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        q, k = self.norm(q, k, v)
+        x = attention(q, k, v, pe=pe)
+        x = self.proj(x)
+        return x
+
+
+class Modulation(nn.Module):
+    def __init__(self, dim, double):
+        super().__init__()
+        self.is_double = double
+        self.multiplier = 6 if double else 3
+        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
+
+    def forward(self, vec):
+        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
+        return (
+            ModulationOut(*out[:3]),
+            ModulationOut(*out[3:]) if self.is_double else None,
+        )
+
+
+class DoubleStreamBlock(nn.Module):
+    def __init__(self, hidden_size, num_heads, mlp_ratio, qkv_bias=False):
+        super().__init__()
+        mlp_hidden_dim = int(hidden_size * mlp_ratio)
+        self.num_heads = num_heads
+        self.hidden_size = hidden_size
+        self.img_mod = Modulation(hidden_size, double=True)
+        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
+        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.img_mlp = nn.Sequential(
+            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+            nn.GELU(approximate="tanh"),
+            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+        )
+        self.txt_mod = Modulation(hidden_size, double=True)
+        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
+        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.txt_mlp = nn.Sequential(
+            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+            nn.GELU(approximate="tanh"),
+            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+        )
+
+    def forward(self, img, txt, vec, pe):
+        img_mod1, img_mod2 = self.img_mod(vec)
+        txt_mod1, txt_mod2 = self.txt_mod(vec)
+        img_modulated = self.img_norm1(img)
+        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+        img_qkv = self.img_attn.qkv(img_modulated)
+        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
+        txt_modulated = self.txt_norm1(txt)
+        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+        txt_qkv = self.txt_attn.qkv(txt_modulated)
+        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
+        q = torch.cat((txt_q, img_q), dim=2)
+        k = torch.cat((txt_k, img_k), dim=2)
+        v = torch.cat((txt_v, img_v), dim=2)
+        attn = attention(q, k, v, pe=pe)
+        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
+        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
+        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
+        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
+        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+        return img, txt
+
+
+class SingleStreamBlock(nn.Module):
+    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, qk_scale=None):
+        super().__init__()
+        self.hidden_dim = hidden_size
+        self.num_heads = num_heads
+        head_dim = hidden_size // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
+        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
+        self.norm = QKNorm(head_dim)
+        self.hidden_size = hidden_size
+        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.mlp_act = nn.GELU(approximate="tanh")
+        self.modulation = Modulation(hidden_size, double=False)
+
+    def forward(self, x, vec, pe):
+        mod, _ = self.modulation(vec)
+        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
+        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        q, k = self.norm(q, k, v)
+        attn = attention(q, k, v, pe=pe)
+        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
+        return x + mod.gate * output
+
+
+class LastLayer(nn.Module):
+    def __init__(self, hidden_size, patch_size, out_channels):
+        super().__init__()
+        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+
+    def forward(self, x, vec):
+        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
+        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+        x = self.linear(x)
+        return x
+
+
+class IntegratedFluxTransformer2DModel(nn.Module):
+    def __init__(self, **kwargs):
+        super().__init__()
+        params = FluxParams(**kwargs)
+
+        self.params = params
+        self.in_channels = params.in_channels * 4
+        self.out_channels = self.in_channels
+
+        if params.hidden_size % params.num_heads != 0:
+            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
+
+        pe_dim = params.hidden_size // params.num_heads
+        if sum(params.axes_dim) != pe_dim:
+            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
+
+        self.hidden_size = params.hidden_size
+        self.num_heads = params.num_heads
+        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
+        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
+        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
+        self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
+        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
+
+        self.double_blocks = nn.ModuleList(
+            [
+                DoubleStreamBlock(
+                    self.hidden_size,
+                    self.num_heads,
+                    mlp_ratio=params.mlp_ratio,
+                    qkv_bias=params.qkv_bias,
+                )
+                for _ in range(params.depth)
+            ]
+        )
+
+        self.single_blocks = nn.ModuleList(
+            [
+                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
+                for _ in range(params.depth_single_blocks)
+            ]
+        )
+
+        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
+
+    def inner_forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None):
+        if img.ndim != 3 or txt.ndim != 3:
+            raise ValueError("Input img and txt tensors must have 3 dimensions.")
+        img = self.img_in(img)
+        vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
+        if self.params.guidance_embed:
+            if guidance is None:
+                raise ValueError("Didn't get guidance strength for guidance distilled model.")
+            vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
+        vec = vec + self.vector_in(y)
+        txt = self.txt_in(txt)
+        ids = torch.cat((txt_ids, img_ids), dim=1)
+        pe = self.pe_embedder(ids)
+        for block in self.double_blocks:
+            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
+        img = torch.cat((txt, img), 1)
+        for block in self.single_blocks:
+            img = block(img, vec=vec, pe=pe)
+        img = img[:, txt.shape[1]:, ...]
+        img = self.final_layer(img, vec)
+        return img
+
+    def forward(self, x, timestep, context, y, guidance, **kwargs):
+        bs, c, h, w = x.shape
+        input_device = x.device
+        input_dtype = x.dtype
+        patch_size = 2
+        pad_h = (patch_size - x.shape[-2] % patch_size) % patch_size
+        pad_w = (patch_size - x.shape[-1] % patch_size) % patch_size
+        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="circular")
+        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
+        h_len = ((h + (patch_size // 2)) // patch_size)
+        w_len = ((w + (patch_size // 2)) // patch_size)
+        img_ids = torch.zeros((h_len, w_len, 3), device=input_device, dtype=input_dtype)
+        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=input_device, dtype=input_dtype)[:, None]
+        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=input_device, dtype=input_dtype)[None, :]
+        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+        txt_ids = torch.zeros((bs, context.shape[1], 3), device=input_device, dtype=input_dtype)
+        out = self.inner_forward(img, img_ids, context, txt_ids, timestep, y, guidance)
+        out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:, :, :h, :w]
+        return out
diff --git a/backend/nn/t5.py b/backend/nn/t5.py
new file mode 100644
index 00000000..8a9cc9b5
--- /dev/null
+++ b/backend/nn/t5.py
@@ -0,0 +1,212 @@
+import torch
+import math
+
+from backend.attention import attention_function
+
+
+activations = {
+    "gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
+    "relu": torch.nn.functional.relu,
+}
+
+
+class T5LayerNorm(torch.nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.empty(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, x):
+        variance = x.pow(2).mean(-1, keepdim=True)
+        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight.to(x) * x
+
+
+class T5DenseActDense(torch.nn.Module):
+    def __init__(self, model_dim, ff_dim, ff_activation):
+        super().__init__()
+        self.wi = torch.nn.Linear(model_dim, ff_dim, bias=False)
+        self.wo = torch.nn.Linear(ff_dim, model_dim, bias=False)
+        self.act = activations[ff_activation]
+
+    def forward(self, x):
+        x = self.act(self.wi(x))
+        x = self.wo(x)
+        return x
+
+
+class T5DenseGatedActDense(torch.nn.Module):
+    def __init__(self, model_dim, ff_dim, ff_activation):
+        super().__init__()
+        self.wi_0 = torch.nn.Linear(model_dim, ff_dim, bias=False)
+        self.wi_1 = torch.nn.Linear(model_dim, ff_dim, bias=False)
+        self.wo = torch.nn.Linear(ff_dim, model_dim, bias=False)
+        self.act = activations[ff_activation]
+
+    def forward(self, x):
+        hidden_gelu = self.act(self.wi_0(x))
+        hidden_linear = self.wi_1(x)
+        x = hidden_gelu * hidden_linear
+        x = self.wo(x)
+        return x
+
+
+class T5LayerFF(torch.nn.Module):
+    def __init__(self, model_dim, ff_dim, ff_activation, gated_act):
+        super().__init__()
+        if gated_act:
+            self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, ff_activation)
+        else:
+            self.DenseReluDense = T5DenseActDense(model_dim, ff_dim, ff_activation)
+
+        self.layer_norm = T5LayerNorm(model_dim)
+
+    def forward(self, x):
+        forwarded_states = self.layer_norm(x)
+        forwarded_states = self.DenseReluDense(forwarded_states)
+        x += forwarded_states
+        return x
+
+
+class T5Attention(torch.nn.Module):
+    def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias):
+        super().__init__()
+        self.q = torch.nn.Linear(model_dim, inner_dim, bias=False)
+        self.k = torch.nn.Linear(model_dim, inner_dim, bias=False)
+        self.v = torch.nn.Linear(model_dim, inner_dim, bias=False)
+        self.o = torch.nn.Linear(inner_dim, model_dim, bias=False)
+        self.num_heads = num_heads
+
+        self.relative_attention_bias = None
+        if relative_attention_bias:
+            self.relative_attention_num_buckets = 32
+            self.relative_attention_max_distance = 128
+            self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads)
+
+    @staticmethod
+    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+        relative_buckets = 0
+        if bidirectional:
+            num_buckets //= 2
+            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
+            relative_position = torch.abs(relative_position)
+        else:
+            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
+
+        max_exact = num_buckets // 2
+        is_small = relative_position < max_exact
+
+        relative_position_if_large = max_exact + (
+            torch.log(relative_position.float() / max_exact)
+            / math.log(max_distance / max_exact)
+            * (num_buckets - max_exact)
+        ).to(torch.long)
+        relative_position_if_large = torch.min(
+            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
+        )
+
+        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
+        return relative_buckets
+
+    def compute_bias(self, query_length, key_length, device, dtype):
+        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
+        relative_position = memory_position - context_position
+        relative_position_bucket = self._relative_position_bucket(
+            relative_position,
+            bidirectional=True,
+            num_buckets=self.relative_attention_num_buckets,
+            max_distance=self.relative_attention_max_distance,
+        )
+        values = self.relative_attention_bias(relative_position_bucket).to(dtype)
+        values = values.permute([2, 0, 1]).unsqueeze(0)
+        return values
+
+    def forward(self, x, mask=None, past_bias=None):
+        q = self.q(x)
+        k = self.k(x)
+        v = self.v(x)
+
+        if self.relative_attention_bias is not None:
+            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device, x.dtype)
+
+        if past_bias is not None:
+            if mask is not None:
+                mask = mask + past_bias
+            else:
+                mask = past_bias
+
+        out = attention_function(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask)
+        return self.o(out), past_bias
+
+
+class T5LayerSelfAttention(torch.nn.Module):
+    def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias):
+        super().__init__()
+        self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias)
+        self.layer_norm = T5LayerNorm(model_dim)
+
+    def forward(self, x, mask=None, past_bias=None):
+        output, past_bias = self.SelfAttention(self.layer_norm(x), mask=mask, past_bias=past_bias)
+        x += output
+        return x, past_bias
+
+
+class T5Block(torch.nn.Module):
+    def __init__(self, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias):
+        super().__init__()
+        self.layer = torch.nn.ModuleList()
+        self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias))
+        self.layer.append(T5LayerFF(model_dim, ff_dim, ff_activation, gated_act))
+
+    def forward(self, x, mask=None, past_bias=None):
+        x, past_bias = self.layer[0](x, mask, past_bias)
+        x = self.layer[-1](x)
+        return x, past_bias
+
+
+class T5Stack(torch.nn.Module):
+    def __init__(self, num_layers, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention):
+        super().__init__()
+
+        self.block = torch.nn.ModuleList(
+            [T5Block(model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias=((not relative_attention) or (i == 0))) for i in range(num_layers)]
+        )
+        self.final_layer_norm = T5LayerNorm(model_dim)
+
+    def forward(self, x, attention_mask=None):
+        mask = None
+
+        if attention_mask is not None:
+            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
+            mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
+
+        past_bias = None
+
+        for i, l in enumerate(self.block):
+            x, past_bias = l(x, mask, past_bias)
+
+        x = self.final_layer_norm(x)
+        return x
+
+
+class T5(torch.nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.num_layers = config["num_layers"]
+        model_dim = config["d_model"]
+
+        self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config["d_ff"], config["dense_act_fn"], config["is_gated_act"], config["num_heads"], config["model_type"] != "umt5")
+        self.shared = torch.nn.Embedding(config["vocab_size"], model_dim)
+
+    def forward(self, input_ids, *args, **kwargs):
+        x = self.shared(input_ids)
+        x = torch.nan_to_num(x)
+        return self.encoder(x, *args, **kwargs)
+
+
+class IntegratedT5(torch.nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.transformer = T5(config)
+        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
diff --git a/backend/operations.py b/backend/operations.py
index 060b42a0..16288e3e 100644
--- a/backend/operations.py
+++ b/backend/operations.py
@@ -22,14 +22,16 @@ def weights_manual_cast(layer, x, skip_dtype=False):
 
     if stream.using_stream:
         with stream.stream_context()(stream.mover_stream):
+            if layer.weight is not None:
+                weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
             if layer.bias is not None:
                 bias = layer.bias.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
-            weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
             signal = stream.mover_stream.record_event()
     else:
+        if layer.weight is not None:
+            weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
         if layer.bias is not None:
            bias = layer.bias.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
-        weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
 
     return weight, bias, signal
diff --git a/backend/sampling/condition.py b/backend/sampling/condition.py
index 00085986..ff6eb7a2 100644
--- a/backend/sampling/condition.py
+++ b/backend/sampling/condition.py
@@ -110,6 +110,9 @@ def compile_conditions(cond):
             )
         )
 
+    if 'guidance' in cond:
+        result['model_conds']['guidance'] = Condition(cond['guidance'])
+
     return [result, ]
 
 
diff --git a/modules/processing.py b/modules/processing.py
index 8eedbca4..d35a8303 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -890,7 +890,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
             p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
 
-            latent_channels = getattr(shared.sd_model, 'latent_channels', opt_C)
+            latent_channels = shared.sd_model.forge_objects.vae.latent_channels
             p.rng = rng.ImageRNG((latent_channels, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
 
             if p.scripts is not None:
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 70aefbc7..1a1826ca 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -318,17 +318,19 @@ def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_s
 
 
 def stack_conds(tensors):
-    # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
-    # and won't be able to torch.stack them. So this fixes that.
-    token_count = max([x.shape[0] for x in tensors])
-    for i in range(len(tensors)):
-        if tensors[i].shape[0] != token_count:
-            last_vector = tensors[i][-1:]
-            last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
-            tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
-
-    return torch.stack(tensors)
-
+    try:
+        result = torch.stack(tensors)
+    except:
+        # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
+        # and won't be able to torch.stack them. So this fixes that.
+        token_count = max([x.shape[0] for x in tensors])
+        for i in range(len(tensors)):
+            if tensors[i].shape[0] != token_count:
+                last_vector = tensors[i][-1:]
+                last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
+                tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+        result = torch.stack(tensors)
+    return result
 
 
 def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py
index c5dda743..7f7ff068 100644
--- a/modules/sd_vae_approx.py
+++ b/modules/sd_vae_approx.py
@@ -58,7 +58,7 @@ def model():
             model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
             download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name)
 
-        loaded_model = VAEApprox(latent_channels=shared.sd_model.latent_channels)
+        loaded_model = VAEApprox(latent_channels=shared.sd_model.forge_objects.vae.latent_channels)
         loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
         loaded_model.eval()
         loaded_model.to(devices.device, devices.dtype)