Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git
revise kernel and add unused files
@@ -56,6 +56,8 @@ parser.add_argument("--cuda-malloc", action="store_true")
parser.add_argument("--cuda-stream", action="store_true")
parser.add_argument("--pin-shared-memory", action="store_true")

parser.add_argument("--i-am-lllyasviel", action="store_true")

args = parser.parse_known_args()[0]

# Some dynamic args that may be changed by webui rather than cmd flags.
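For reference, every "store_true" flag above defaults to False, and parse_known_args() returns a (namespace, remaining) pair instead of erroring on flags it does not recognize. A minimal standalone sketch of that behavior (plain argparse, not the actual Forge launcher):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--cuda-stream", action="store_true")
parser.add_argument("--pin-shared-memory", action="store_true")
parser.add_argument("--i-am-lllyasviel", action="store_true")

# parse_known_args tolerates flags registered elsewhere by the webui.
args, unknown = parser.parse_known_args(["--cuda-stream", "--some-webui-flag"])
print(args.cuda_stream)      # True
print(args.i_am_lllyasviel)  # False; this flag gates the experimental Flux engine below
print(unknown)               # ['--some-webui-flag']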
@@ -34,7 +34,6 @@ class ForgeDiffusionEngine:
        self.first_stage_model = None  # set this so that you can change VAE in UI

        # WebUI Dirty Legacy
        self.latent_channels = 4
        self.is_sd1 = False
        self.is_sd2 = False
        self.is_sdxl = False
103  backend/diffusion_engine/flux.py  Normal file
@@ -0,0 +1,103 @@
import torch

from huggingface_guess import model_list
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
from backend.patcher.clip import CLIP
from backend.patcher.vae import VAE
from backend.patcher.unet import UnetPatcher
from backend.text_processing.classic_engine import ClassicTextProcessingEngine
from backend.text_processing.t5_engine import T5TextProcessingEngine
from backend.args import dynamic_args, args
from backend.modules.k_prediction import PredictionFlux
from backend import memory_management


class Flux(ForgeDiffusionEngine):
    matched_guesses = [model_list.Flux]

    def __init__(self, estimated_config, huggingface_components):
        if not args.i_am_lllyasviel:
            raise NotImplementedError('Flux is not implemented yet!')

        super().__init__(estimated_config, huggingface_components)
        self.is_inpaint = False

        clip = CLIP(
            model_dict={
                'clip_l': huggingface_components['text_encoder'],
                't5xxl': huggingface_components['text_encoder_2']
            },
            tokenizer_dict={
                'clip_l': huggingface_components['tokenizer'],
                't5xxl': huggingface_components['tokenizer_2']
            }
        )

        vae = VAE(model=huggingface_components['vae'])

        unet = UnetPatcher.from_model(
            model=huggingface_components['transformer'],
            diffusers_scheduler=None,
            k_predictor=PredictionFlux(sigma_data=1.0, prediction_type='const', shift=1.0, timesteps=10000)
        )

        self.text_processing_engine_l = ClassicTextProcessingEngine(
            text_encoder=clip.cond_stage_model.clip_l,
            tokenizer=clip.tokenizer.clip_l,
            embedding_dir=dynamic_args['embedding_dir'],
            embedding_key='clip_l',
            embedding_expected_shape=768,
            emphasis_name=dynamic_args['emphasis_name'],
            text_projection=False,
            minimal_clip_skip=1,
            clip_skip=1,
            return_pooled=True,
            final_layer_norm=True,
        )

        self.text_processing_engine_t5 = T5TextProcessingEngine(
            text_encoder=clip.cond_stage_model.t5xxl,
            tokenizer=clip.tokenizer.t5xxl,
            emphasis_name=dynamic_args['emphasis_name'],
        )

        self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
        self.forge_objects_original = self.forge_objects.shallow_copy()
        self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()

        # WebUI Legacy
        self.first_stage_model = vae.first_stage_model

    def set_clip_skip(self, clip_skip):
        self.text_processing_engine_l.clip_skip = clip_skip

    @torch.inference_mode()
    def get_learned_conditioning(self, prompt: list[str]):
        memory_management.load_model_gpu(self.forge_objects.clip.patcher)
        cond_l, pooled_l = self.text_processing_engine_l(prompt)
        cond_t5 = self.text_processing_engine_t5(prompt)

        cond = dict(
            crossattn=cond_t5,
            vector=pooled_l,
            guidance=torch.FloatTensor([3.5] * len(prompt))
        )

        return cond

    @torch.inference_mode()
    def get_prompt_lengths_on_ui(self, prompt):
        _, token_count = self.text_processing_engine_t5.process_texts([prompt])
        return token_count, max(255, token_count)

    @torch.inference_mode()
    def encode_first_stage(self, x):
        sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
        sample = self.forge_objects.vae.first_stage_model.process_in(sample)
        return sample.to(x)

    @torch.inference_mode()
    def decode_first_stage(self, x):
        sample = self.forge_objects.vae.first_stage_model.process_out(x)
        sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
        return sample.to(x)
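As a rough sketch of the conditioning dict that get_learned_conditioning assembles, the snippet below uses dummy tensors in place of the real encoder outputs. The batch size, T5 sequence length, and T5 feature width are illustrative assumptions; only the pooled CLIP-L width of 768 and the 3.5 guidance default come from the code above.

import torch

prompts = ["a photo of a cat", "an oil painting of a fox"]
batch = len(prompts)

cond_t5 = torch.randn(batch, 256, 4096)  # stand-in for the T5-XXL token features
pooled_l = torch.randn(batch, 768)       # stand-in for the pooled CLIP-L vector

cond = dict(
    crossattn=cond_t5,                          # sequence conditioning for cross-attention
    vector=pooled_l,                            # pooled vector conditioning
    guidance=torch.FloatTensor([3.5] * batch),  # one guidance strength per prompt
)

assert cond['guidance'].shape == (batch,)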
@@ -19,11 +19,10 @@ from backend.nn.unet import IntegratedUNet2DConditionModel
from backend.diffusion_engine.sd15 import StableDiffusion
from backend.diffusion_engine.sd20 import StableDiffusion2
from backend.diffusion_engine.sdxl import StableDiffusionXL
from backend.diffusion_engine.flux import Flux


possible_models = [
    StableDiffusion, StableDiffusion2, StableDiffusionXL,
]
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXL, Flux]


logging.getLogger("diffusers").setLevel(logging.ERROR)
@@ -65,25 +64,25 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
        ], log_name=cls_name)

        return model
    # if cls_name == 'T5EncoderModel':
    # from backend.nn.t5 import IntegratedT5
    # config = read_arbitrary_config(config_path)
    #
    # dtype = memory_management.text_encoder_dtype()
    # sd_dtype = state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight'].dtype
    # need_cast = False
    #
    # if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
    # dtype = sd_dtype
    # need_cast = True
    #
    # with modeling_utils.no_init_weights():
    # with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=need_cast):
    # model = IntegratedT5(config)
    #
    # load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight'])
    #
    # return model
    if cls_name == 'T5EncoderModel':
        from backend.nn.t5 import IntegratedT5
        config = read_arbitrary_config(config_path)

        dtype = memory_management.text_encoder_dtype()
        sd_dtype = state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight'].dtype
        need_cast = False

        if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
            dtype = sd_dtype
            need_cast = True

        with modeling_utils.no_init_weights():
            with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=need_cast):
                model = IntegratedT5(config)

        load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight'])

        return model
    if cls_name == 'UNet2DConditionModel':
        unet_config = guess.unet_config.copy()
        state_dict_size = memory_management.state_dict_size(state_dict)
@@ -97,20 +96,20 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p

        load_state_dict(model, state_dict)
        return model
    # if cls_name == 'FluxTransformer2DModel':
    # from backend.nn.flux import IntegratedFluxTransformer2DModel
    # unet_config = guess.unet_config.copy()
    # state_dict_size = memory_management.state_dict_size(state_dict)
    # ini_dtype = memory_management.unet_dtype(model_params=state_dict_size)
    # ini_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=ini_dtype)
    # to_args = dict(device=ini_device, dtype=ini_dtype)
    #
    # with using_forge_operations(**to_args):
    # model = IntegratedFluxTransformer2DModel(**unet_config).to(**to_args)
    # model.config = unet_config
    #
    # load_state_dict(model, state_dict)
    # return model
    if cls_name == 'FluxTransformer2DModel':
        from backend.nn.flux import IntegratedFluxTransformer2DModel
        unet_config = guess.unet_config.copy()
        state_dict_size = memory_management.state_dict_size(state_dict)
        ini_dtype = memory_management.unet_dtype(model_params=state_dict_size)
        ini_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=ini_dtype)
        to_args = dict(device=ini_device, dtype=ini_dtype)

        with using_forge_operations(**to_args):
            model = IntegratedFluxTransformer2DModel(**unet_config).to(**to_args)
            model.config = unet_config

        load_state_dict(model, state_dict)
        return model

    print(f'Skipped: {component_name} = {lib_name}.{cls_name}')
    return None
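The T5EncoderModel branch above keeps fp8 checkpoints in their stored dtype and enables manual casting rather than upcasting the weights at load time. A standalone sketch of that decision rule follows; the helper name pick_t5_dtype and the float16 fallback are made up for illustration, and only the fp8 check mirrors the loader code.

import torch

def pick_t5_dtype(sd_dtype, default_dtype=torch.float16):
    # Keep fp8 weights as stored and let manual casting upcast per layer at runtime;
    # anything else is loaded in the default text-encoder dtype.
    if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
        return sd_dtype, True   # (dtype, need_cast)
    return default_dtype, False

print(pick_t5_dtype(torch.float8_e4m3fn))  # (torch.float8_e4m3fn, True)
print(pick_t5_dtype(torch.float16))        # (torch.float16, False)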
326  backend/nn/flux.py  Normal file
@@ -0,0 +1,326 @@
# Single File Implementation of Flux, by Forge
# See also https://github.com/black-forest-labs/flux


import math
import torch
from einops import rearrange, repeat
from torch import nn
from dataclasses import dataclass
from backend.attention import attention_function


def attention(q, k, v, pe):
    q, k = apply_rope(q, k, pe)
    x = attention_function(q, k, v, q.shape[1], skip_reshape=True)
    return x


def rope(pos, dim, theta):
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta ** scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()


def apply_rope(xq, xk, freqs_cis):
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


def timestep_embedding(t, dim, max_period=10000, time_factor=1000.0):
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding


@dataclass
class ModulationOut:
    shift: torch.Tensor
    scale: torch.Tensor
    gate: torch.Tensor


@dataclass
class FluxParams:
    in_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    theta: int
    qkv_bias: bool
    guidance_embed: bool


class EmbedND(nn.Module):
    def __init__(self, dim, theta, axes_dim):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids):
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )
        return emb.unsqueeze(1)


class MLPEmbedder(nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x):
        return self.out_layer(self.silu(self.in_layer(x)))


class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        to_args = dict(device=x.device, dtype=x.dtype)
        x = x.float()
        rrms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + 1e-6)
        return (x * rrms.to(x) * self.scale.to(x)).to(**to_args)


class QKNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q, k, v):
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)


class SelfAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, pe):
        qkv = self.qkv(x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = self.proj(x)
        return x


class Modulation(nn.Module):
    def __init__(self, dim, double):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec):
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
        return (
            ModulationOut(*out[:3]),
            ModulationOut(*out[3:]) if self.is_double else None,
        )


class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size, num_heads, mlp_ratio, qkv_bias=False):
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )
        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def forward(self, img, txt, vec, pe):
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)
        attn = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
        return img, txt


class SingleStreamBlock(nn.Module):
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, qk_scale=None):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
        self.norm = QKNorm(head_dim)
        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x, vec, pe):
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        attn = attention(q, k, v, pe=pe)
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + mod.gate * output


class LastLayer(nn.Module):
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x, vec):
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x


class IntegratedFluxTransformer2DModel(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        params = FluxParams(**kwargs)
        self.params = params
        self.in_channels = params.in_channels * 4
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )
        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )
        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def inner_forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None):
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)
        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1]:, ...]
        img = self.final_layer(img, vec)
        return img

    def forward(self, x, timestep, context, y, guidance, **kwargs):
        bs, c, h, w = x.shape
        input_device = x.device
        input_dtype = x.dtype
        patch_size = 2
        pad_h = (patch_size - x.shape[-2] % patch_size) % patch_size
        pad_w = (patch_size - x.shape[-1] % patch_size) % patch_size
        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="circular")
        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
        h_len = ((h + (patch_size // 2)) // patch_size)
        w_len = ((w + (patch_size // 2)) // patch_size)
        img_ids = torch.zeros((h_len, w_len, 3), device=input_device, dtype=input_dtype)
        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=input_device, dtype=input_dtype)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=input_device, dtype=input_dtype)[None, :]
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
        txt_ids = torch.zeros((bs, context.shape[1], 3), device=input_device, dtype=input_dtype)
        out = self.inner_forward(img, img_ids, context, txt_ids, timestep, y, guidance)
        out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:, :, :h, :w]
        return out
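Two small sanity sketches for the file above, assuming the repository root is on the Python path so it imports as backend.nn.flux. The first checks that the RoPE helpers apply a pure rotation (per-token norms are unchanged); the second shows the 2x2 patch bookkeeping that forward() performs around the transformer. Tensor sizes are illustrative, not Flux's real configuration.

import torch
from einops import rearrange
from backend.nn.flux import rope, apply_rope

# RoPE: per-position rotation matrices, norm-preserving.
B, H, L, D = 1, 2, 8, 16                             # batch, heads, tokens, head dim
pos = torch.arange(L, dtype=torch.float64)[None, :]  # (B, L) position ids
pe = rope(pos, D, theta=10000).unsqueeze(1)          # (B, 1, L, D/2, 2, 2), as EmbedND.forward does
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
q_rot, k_rot = apply_rope(q, k, pe)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4))  # True

# Patchify: a (bs, c, h, w) latent becomes (bs, h//2 * w//2, c*4) tokens and back.
x = torch.randn(1, 16, 64, 64)
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
print(img.shape)  # torch.Size([1, 1024, 64])
back = rearrange(img, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=32, w=32, ph=2, pw=2)
print(torch.equal(x, back))  # True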
212  backend/nn/t5.py  Normal file
@@ -0,0 +1,212 @@
import torch
import math

from backend.attention import attention_function


activations = {
    "gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
    "relu": torch.nn.functional.relu,
}


class T5LayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight.to(x) * x


class T5DenseActDense(torch.nn.Module):
    def __init__(self, model_dim, ff_dim, ff_activation):
        super().__init__()
        self.wi = torch.nn.Linear(model_dim, ff_dim, bias=False)
        self.wo = torch.nn.Linear(ff_dim, model_dim, bias=False)
        self.act = activations[ff_activation]

    def forward(self, x):
        x = self.act(self.wi(x))
        x = self.wo(x)
        return x


class T5DenseGatedActDense(torch.nn.Module):
    def __init__(self, model_dim, ff_dim, ff_activation):
        super().__init__()
        self.wi_0 = torch.nn.Linear(model_dim, ff_dim, bias=False)
        self.wi_1 = torch.nn.Linear(model_dim, ff_dim, bias=False)
        self.wo = torch.nn.Linear(ff_dim, model_dim, bias=False)
        self.act = activations[ff_activation]

    def forward(self, x):
        hidden_gelu = self.act(self.wi_0(x))
        hidden_linear = self.wi_1(x)
        x = hidden_gelu * hidden_linear
        x = self.wo(x)
        return x


class T5LayerFF(torch.nn.Module):
    def __init__(self, model_dim, ff_dim, ff_activation, gated_act):
        super().__init__()
        if gated_act:
            self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, ff_activation)
        else:
            self.DenseReluDense = T5DenseActDense(model_dim, ff_dim, ff_activation)

        self.layer_norm = T5LayerNorm(model_dim)

    def forward(self, x):
        forwarded_states = self.layer_norm(x)
        forwarded_states = self.DenseReluDense(forwarded_states)
        x += forwarded_states
        return x


class T5Attention(torch.nn.Module):
    def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias):
        super().__init__()
        self.q = torch.nn.Linear(model_dim, inner_dim, bias=False)
        self.k = torch.nn.Linear(model_dim, inner_dim, bias=False)
        self.v = torch.nn.Linear(model_dim, inner_dim, bias=False)
        self.o = torch.nn.Linear(inner_dim, model_dim, bias=False)
        self.num_heads = num_heads

        self.relative_attention_bias = None
        if relative_attention_bias:
            self.relative_attention_num_buckets = 32
            self.relative_attention_max_distance = 128
            self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))

        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length, device, dtype):
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position
        relative_position_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=True,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket).to(dtype)
        values = values.permute([2, 0, 1]).unsqueeze(0)
        return values

    def forward(self, x, mask=None, past_bias=None):
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        if self.relative_attention_bias is not None:
            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device, x.dtype)

        if past_bias is not None:
            if mask is not None:
                mask = mask + past_bias
            else:
                mask = past_bias

        out = attention_function(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask)
        return self.o(out), past_bias


class T5LayerSelfAttention(torch.nn.Module):
    def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias):
        super().__init__()
        self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias)
        self.layer_norm = T5LayerNorm(model_dim)

    def forward(self, x, mask=None, past_bias=None):
        output, past_bias = self.SelfAttention(self.layer_norm(x), mask=mask, past_bias=past_bias)
        x += output
        return x, past_bias


class T5Block(torch.nn.Module):
    def __init__(self, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias):
        super().__init__()
        self.layer = torch.nn.ModuleList()
        self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias))
        self.layer.append(T5LayerFF(model_dim, ff_dim, ff_activation, gated_act))

    def forward(self, x, mask=None, past_bias=None):
        x, past_bias = self.layer[0](x, mask, past_bias)
        x = self.layer[-1](x)
        return x, past_bias


class T5Stack(torch.nn.Module):
    def __init__(self, num_layers, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention):
        super().__init__()

        self.block = torch.nn.ModuleList(
            [T5Block(model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias=((not relative_attention) or (i == 0))) for i in range(num_layers)]
        )
        self.final_layer_norm = T5LayerNorm(model_dim)

    def forward(self, x, attention_mask=None):
        mask = None

        if attention_mask is not None:
            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
            mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))

        past_bias = None

        for i, l in enumerate(self.block):
            x, past_bias = l(x, mask, past_bias)

        x = self.final_layer_norm(x)
        return x


class T5(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_layers = config["num_layers"]
        model_dim = config["d_model"]

        self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config["d_ff"], config["dense_act_fn"], config["is_gated_act"], config["num_heads"], config["model_type"] != "umt5")
        self.shared = torch.nn.Embedding(config["vocab_size"], model_dim)

    def forward(self, input_ids, *args, **kwargs):
        x = self.shared(input_ids)
        x = torch.nan_to_num(x)
        return self.encoder(x, *args, **kwargs)


class IntegratedT5(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transformer = T5(config)
        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
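A small sketch of what _relative_position_bucket produces in the encoder's bidirectional case (importable as backend.nn.t5 under the same path assumption as above): small offsets each get their own bucket, larger offsets share logarithmically spaced buckets, and the sign of the offset selects the lower or upper half of the 32 buckets.

import torch
from backend.nn.t5 import T5Attention

query = torch.arange(6, dtype=torch.long)[:, None]
key = torch.arange(6, dtype=torch.long)[None, :]
rel = key - query  # relative offsets, shape (6, 6)

buckets = T5Attention._relative_position_bucket(rel, bidirectional=True, num_buckets=32, max_distance=128)
print(buckets.shape)             # torch.Size([6, 6])
print(buckets[0])                # tensor([ 0, 17, 18, 19, 20, 21])
print(bool(buckets.max() < 32))  # True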
@@ -22,14 +22,16 @@ def weights_manual_cast(layer, x, skip_dtype=False):

    if stream.using_stream:
        with stream.stream_context()(stream.mover_stream):
            if layer.weight is not None:
                weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
            if layer.bias is not None:
                bias = layer.bias.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
            weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
            signal = stream.mover_stream.record_event()
    else:
        if layer.weight is not None:
            weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
        if layer.bias is not None:
            bias = layer.bias.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
        weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)

    return weight, bias, signal
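The stream branch above overlaps weight transfers with compute: the .to() copies are issued on a separate mover stream and an event (signal) is recorded so the consumer can wait for the copy before touching the weights. A generic sketch of that pattern with plain torch.cuda primitives rather than Forge's stream wrapper, guarded so it only runs when CUDA is present:

import torch

if torch.cuda.is_available():
    mover = torch.cuda.Stream()
    weight_cpu = torch.randn(4096, 4096, pin_memory=True)  # pinned memory allows truly async H2D copies

    with torch.cuda.stream(mover):
        weight_gpu = weight_cpu.to('cuda', non_blocking=True)
        signal = mover.record_event()  # marks the point where the copy was queued

    # ... unrelated GPU work can proceed on the default stream here ...

    torch.cuda.current_stream().wait_event(signal)  # do not read weight_gpu before the copy lands
    print(weight_gpu.sum())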
@@ -110,6 +110,9 @@ def compile_conditions(cond):
        )
    )

    if 'guidance' in cond:
        result['model_conds']['guidance'] = Condition(cond['guidance'])

    return [result, ]
@@ -890,7 +890,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
            p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            latent_channels = getattr(shared.sd_model, 'latent_channels', opt_C)
            latent_channels = shared.sd_model.forge_objects.vae.latent_channels
            p.rng = rng.ImageRNG((latent_channels, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)

            if p.scripts is not None:
@@ -318,17 +318,19 @@ def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_s

def stack_conds(tensors):
    # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
    # and won't be able to torch.stack them. So this fixes that.
    token_count = max([x.shape[0] for x in tensors])
    for i in range(len(tensors)):
        if tensors[i].shape[0] != token_count:
            last_vector = tensors[i][-1:]
            last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
            tensors[i] = torch.vstack([tensors[i], last_vector_repeated])

    return torch.stack(tensors)

    try:
        result = torch.stack(tensors)
    except:
        # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
        # and won't be able to torch.stack them. So this fixes that.
        token_count = max([x.shape[0] for x in tensors])
        for i in range(len(tensors)):
            if tensors[i].shape[0] != token_count:
                last_vector = tensors[i][-1:]
                last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
                tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
        result = torch.stack(tensors)
    return result
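A quick illustration of the failure mode the new try/except handles: conditioning tensors of different token lengths cannot be stacked directly, so the shorter ones are padded by repeating their last vector. The tensors below are dummies; only the padding logic mirrors stack_conds.

import torch

conds = [torch.randn(77, 768), torch.randn(154, 768)]  # prompts of different token lengths

try:
    stacked = torch.stack(conds)
except RuntimeError:
    token_count = max(x.shape[0] for x in conds)
    conds = [x if x.shape[0] == token_count else torch.vstack([x, x[-1:].repeat(token_count - x.shape[0], 1)])
             for x in conds]
    stacked = torch.stack(conds)

print(stacked.shape)  # torch.Size([2, 154, 768])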

def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
@@ -58,7 +58,7 @@ def model():
        model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
        download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name)

        loaded_model = VAEApprox(latent_channels=shared.sd_model.latent_channels)
        loaded_model = VAEApprox(latent_channels=shared.sd_model.forge_objects.vae.latent_channels)
        loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
        loaded_model.eval()
        loaded_model.to(devices.device, devices.dtype)