revise kernel

and add unused files
lllyasviel authored 2024-08-07 16:51:24 -07:00 (committed by GitHub)
parent a07c758658
commit a6baf4a4b5
11 changed files with 700 additions and 52 deletions

View File

@@ -56,6 +56,8 @@ parser.add_argument("--cuda-malloc", action="store_true")
parser.add_argument("--cuda-stream", action="store_true")
parser.add_argument("--pin-shared-memory", action="store_true")
parser.add_argument("--i-am-lllyasviel", action="store_true")
args = parser.parse_known_args()[0]
# Some dynamic args that may be changed by webui rather than cmd flags.
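# A minimal, standalone sketch (toy flags, not part of this commit) of why
# parse_known_args()[0] is used above: unrecognized options are collected instead
# of raising, so other launchers and extensions can pass their own flags safely.
import argparse

demo_parser = argparse.ArgumentParser()
demo_parser.add_argument("--cuda-stream", action="store_true")
ns, unknown = demo_parser.parse_known_args(["--cuda-stream", "--some-extension-flag"])
print(ns.cuda_stream)   # True
print(unknown)          # ['--some-extension-flag'] is carried along rather than rejected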

View File

@@ -34,7 +34,6 @@ class ForgeDiffusionEngine:
self.first_stage_model = None # set this so that you can change VAE in UI
# WebUI Dirty Legacy
self.latent_channels = 4
self.is_sd1 = False
self.is_sd2 = False
self.is_sdxl = False

View File

@@ -0,0 +1,103 @@
import torch
from huggingface_guess import model_list
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
from backend.patcher.clip import CLIP
from backend.patcher.vae import VAE
from backend.patcher.unet import UnetPatcher
from backend.text_processing.classic_engine import ClassicTextProcessingEngine
from backend.text_processing.t5_engine import T5TextProcessingEngine
from backend.args import dynamic_args, args
from backend.modules.k_prediction import PredictionFlux
from backend import memory_management
class Flux(ForgeDiffusionEngine):
matched_guesses = [model_list.Flux]
def __init__(self, estimated_config, huggingface_components):
if not args.i_am_lllyasviel:
raise NotImplementedError('Flux is not implemented yet!')
super().__init__(estimated_config, huggingface_components)
self.is_inpaint = False
clip = CLIP(
model_dict={
'clip_l': huggingface_components['text_encoder'],
't5xxl': huggingface_components['text_encoder_2']
},
tokenizer_dict={
'clip_l': huggingface_components['tokenizer'],
't5xxl': huggingface_components['tokenizer_2']
}
)
vae = VAE(model=huggingface_components['vae'])
unet = UnetPatcher.from_model(
model=huggingface_components['transformer'],
diffusers_scheduler=None,
k_predictor=PredictionFlux(sigma_data=1.0, prediction_type='const', shift=1.0, timesteps=10000)
)
self.text_processing_engine_l = ClassicTextProcessingEngine(
text_encoder=clip.cond_stage_model.clip_l,
tokenizer=clip.tokenizer.clip_l,
embedding_dir=dynamic_args['embedding_dir'],
embedding_key='clip_l',
embedding_expected_shape=768,
emphasis_name=dynamic_args['emphasis_name'],
text_projection=False,
minimal_clip_skip=1,
clip_skip=1,
return_pooled=True,
final_layer_norm=True,
)
self.text_processing_engine_t5 = T5TextProcessingEngine(
text_encoder=clip.cond_stage_model.t5xxl,
tokenizer=clip.tokenizer.t5xxl,
emphasis_name=dynamic_args['emphasis_name'],
)
self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
self.forge_objects_original = self.forge_objects.shallow_copy()
self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
# WebUI Legacy
self.first_stage_model = vae.first_stage_model
def set_clip_skip(self, clip_skip):
self.text_processing_engine_l.clip_skip = clip_skip
@torch.inference_mode()
def get_learned_conditioning(self, prompt: list[str]):
memory_management.load_model_gpu(self.forge_objects.clip.patcher)
cond_l, pooled_l = self.text_processing_engine_l(prompt)
cond_t5 = self.text_processing_engine_t5(prompt)
cond = dict(
crossattn=cond_t5,
vector=pooled_l,
guidance=torch.FloatTensor([3.5] * len(prompt))
)
return cond
@torch.inference_mode()
def get_prompt_lengths_on_ui(self, prompt):
_, token_count = self.text_processing_engine_t5.process_texts([prompt])
return token_count, max(255, token_count)
@torch.inference_mode()
def encode_first_stage(self, x):
sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
sample = self.forge_objects.vae.first_stage_model.process_in(sample)
return sample.to(x)
@torch.inference_mode()
def decode_first_stage(self, x):
sample = self.forge_objects.vae.first_stage_model.process_out(x)
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
return sample.to(x)
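# Hedged sketch (toy tensors, not part of this commit) of the range/layout handling
# in encode_first_stage/decode_first_stage above: WebUI hands the engine [-1, 1]
# BCHW images, while the VAE wrapper works in [0, 1] BHWC, so the engine converts
# on the way in and reverses both steps on the way out.
import torch

x = torch.rand(1, 3, 8, 8) * 2.0 - 1.0        # [-1, 1], BCHW, as supplied by WebUI
vae_in = x.movedim(1, -1) * 0.5 + 0.5         # [0, 1], BHWC, as passed to vae.encode
restored = vae_in.movedim(-1, 1) * 2.0 - 1.0  # the inverse applied after vae.decode
print(vae_in.shape)                           # torch.Size([1, 8, 8, 3])
print(torch.allclose(restored, x, atol=1e-6)) # True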

View File

@@ -19,11 +19,10 @@ from backend.nn.unet import IntegratedUNet2DConditionModel
from backend.diffusion_engine.sd15 import StableDiffusion
from backend.diffusion_engine.sd20 import StableDiffusion2
from backend.diffusion_engine.sdxl import StableDiffusionXL
from backend.diffusion_engine.flux import Flux
possible_models = [
StableDiffusion, StableDiffusion2, StableDiffusionXL,
]
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXL, Flux]
logging.getLogger("diffusers").setLevel(logging.ERROR)
@@ -65,25 +64,25 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
], log_name=cls_name)
return model
# if cls_name == 'T5EncoderModel':
# from backend.nn.t5 import IntegratedT5
# config = read_arbitrary_config(config_path)
#
# dtype = memory_management.text_encoder_dtype()
# sd_dtype = state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight'].dtype
# need_cast = False
#
# if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
# dtype = sd_dtype
# need_cast = True
#
# with modeling_utils.no_init_weights():
# with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=need_cast):
# model = IntegratedT5(config)
#
# load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight'])
#
# return model
if cls_name == 'T5EncoderModel':
from backend.nn.t5 import IntegratedT5
config = read_arbitrary_config(config_path)
dtype = memory_management.text_encoder_dtype()
sd_dtype = state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight'].dtype
need_cast = False
if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
dtype = sd_dtype
need_cast = True
with modeling_utils.no_init_weights():
with using_forge_operations(device=memory_management.cpu, dtype=dtype, manual_cast_enabled=need_cast):
model = IntegratedT5(config)
load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight'])
return model
if cls_name == 'UNet2DConditionModel':
unet_config = guess.unet_config.copy()
state_dict_size = memory_management.state_dict_size(state_dict)
@@ -97,20 +96,20 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
load_state_dict(model, state_dict)
return model
# if cls_name == 'FluxTransformer2DModel':
# from backend.nn.flux import IntegratedFluxTransformer2DModel
# unet_config = guess.unet_config.copy()
# state_dict_size = memory_management.state_dict_size(state_dict)
# ini_dtype = memory_management.unet_dtype(model_params=state_dict_size)
# ini_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=ini_dtype)
# to_args = dict(device=ini_device, dtype=ini_dtype)
#
# with using_forge_operations(**to_args):
# model = IntegratedFluxTransformer2DModel(**unet_config).to(**to_args)
# model.config = unet_config
#
# load_state_dict(model, state_dict)
# return model
if cls_name == 'FluxTransformer2DModel':
from backend.nn.flux import IntegratedFluxTransformer2DModel
unet_config = guess.unet_config.copy()
state_dict_size = memory_management.state_dict_size(state_dict)
ini_dtype = memory_management.unet_dtype(model_params=state_dict_size)
ini_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=ini_dtype)
to_args = dict(device=ini_device, dtype=ini_dtype)
with using_forge_operations(**to_args):
model = IntegratedFluxTransformer2DModel(**unet_config).to(**to_args)
model.config = unet_config
load_state_dict(model, state_dict)
return model
print(f'Skipped: {component_name} = {lib_name}.{cls_name}')
return None
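# Hedged sketch (standalone toy function, not this repo's API) of the dtype decision
# in the T5EncoderModel branch above: if the checkpoint already stores float8
# weights, keep that storage dtype and enable manual casting at compute time;
# otherwise fall back to the usual text-encoder dtype.
import torch

def pick_text_encoder_dtype(sd_dtype, default_dtype=torch.float16):
    if sd_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return sd_dtype, True      # keep fp8 storage, cast per layer when computing
    return default_dtype, False

print(pick_text_encoder_dtype(torch.float8_e4m3fn))  # (torch.float8_e4m3fn, True)
print(pick_text_encoder_dtype(torch.float16))        # (torch.float16, False)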

backend/nn/flux.py (new file, 326 lines)
View File

@@ -0,0 +1,326 @@
# Single File Implementation of Flux, by Forge
# See also https://github.com/black-forest-labs/flux
import math
import torch
from einops import rearrange, repeat
from torch import nn
from dataclasses import dataclass
from backend.attention import attention_function
def attention(q, k, v, pe):
q, k = apply_rope(q, k, pe)
x = attention_function(q, k, v, q.shape[1], skip_reshape=True)
return x
def rope(pos, dim, theta):
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta ** scale)
out = torch.einsum("...n,d->...nd", pos, omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.float()
def apply_rope(xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
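# Hedged usage sketch (toy sizes, not part of this commit) of rope()/apply_rope()
# above: rope() builds a 2x2 rotation matrix per position and frequency, and
# apply_rope() rotates consecutive channel pairs (a, b) of q/k into
# (a*cos(p*w) - b*sin(p*w), a*sin(p*w) + b*cos(p*w)), preserving their length.
# (Assumes torch and the two functions above are in scope.)
_pos = torch.tensor([[0.0, 1.0, 2.0]])               # one batch, three positions
_pe = rope(_pos, dim=4, theta=10000).unsqueeze(1)    # (1, 1, 3, 2, 2, 2)
_q = torch.randn(1, 1, 3, 4)                         # (batch, heads, tokens, head_dim)
_k = torch.randn(1, 1, 3, 4)
_q_rot, _k_rot = apply_rope(_q, _k, _pe)
print(_q_rot.shape)                                  # torch.Size([1, 1, 3, 4])
print(torch.allclose(_q.norm(dim=-1), _q_rot.norm(dim=-1), atol=1e-5))  # True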
def timestep_embedding(t, dim, max_period=10000, time_factor=1000.0):
t = time_factor * t
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
if torch.is_floating_point(t):
embedding = embedding.to(t)
return embedding
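# Hedged sketch (toy values) of timestep_embedding() above: it is the standard
# sinusoidal embedding, with time_factor=1000 scaling the (typically small)
# timestep values up before embedding; the first half of the channels are cosines,
# the second half sines. (Assumes torch and the function above are in scope.)
_t = torch.tensor([0.0, 0.5, 1.0])
_emb = timestep_embedding(_t, dim=256)
print(_emb.shape)                                # torch.Size([3, 256])
print(_emb[0, 0].item(), _emb[0, 128].item())    # 1.0 0.0 at t=0: cos(0), sin(0)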
@dataclass
class ModulationOut:
shift: torch.Tensor
scale: torch.Tensor
gate: torch.Tensor
@dataclass
class FluxParams:
in_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list[int]
theta: int
qkv_bias: bool
guidance_embed: bool
class EmbedND(nn.Module):
def __init__(self, dim, theta, axes_dim):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids):
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(1)
class MLPEmbedder(nn.Module):
def __init__(self, in_dim, hidden_dim):
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
self.silu = nn.SiLU()
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
def forward(self, x):
return self.out_layer(self.silu(self.in_layer(x)))
class RMSNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.scale = nn.Parameter(torch.ones(dim))
def forward(self, x):
to_args = dict(device=x.device, dtype=x.dtype)
x = x.float()
rrms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + 1e-6)
return (x * rrms.to(x) * self.scale.to(x)).to(**to_args)
class QKNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.query_norm = RMSNorm(dim)
self.key_norm = RMSNorm(dim)
def forward(self, q, k, v):
q = self.query_norm(q)
k = self.key_norm(k)
return q.to(v), k.to(v)
class SelfAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.norm = QKNorm(head_dim)
self.proj = nn.Linear(dim, dim)
def forward(self, x, pe):
qkv = self.qkv(x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
x = attention(q, k, v, pe=pe)
x = self.proj(x)
return x
class Modulation(nn.Module):
def __init__(self, dim, double):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
def forward(self, vec):
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
return (
ModulationOut(*out[:3]),
ModulationOut(*out[3:]) if self.is_double else None,
)
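# Hedged usage sketch (toy sizes, not part of this commit): Modulation predicts
# shift/scale/gate from the pooled conditioning vector, and the blocks below apply
# them adaLN-style, x_mod = (1 + scale) * norm(x) + shift, with the branch output
# added back through the gate as x = x + gate * branch(x_mod).
_mod = Modulation(8, double=False)
_vec = torch.randn(2, 8)                        # pooled conditioning vector
_m, _unused = _mod(_vec)
_x = torch.randn(2, 5, 8)                       # (batch, tokens, hidden)
_x_mod = (1 + _m.scale) * nn.functional.layer_norm(_x, (8,)) + _m.shift
print(_m.shift.shape, _x_mod.shape)             # torch.Size([2, 1, 8]) torch.Size([2, 5, 8])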
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size, num_heads, mlp_ratio, qkv_bias=False):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_mod = Modulation(hidden_size, double=True)
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
self.txt_mod = Modulation(hidden_size, double=True)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
def forward(self, img, txt, vec, pe):
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
q = torch.cat((txt_q, img_q), dim=2)
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
attn = attention(q, k, v, pe=pe)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
return img, txt
class SingleStreamBlock(nn.Module):
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, qk_scale=None):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
self.norm = QKNorm(head_dim)
self.hidden_size = hidden_size
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False)
def forward(self, x, vec, pe):
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
attn = attention(q, k, v, pe=pe)
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + mod.gate * output
class LastLayer(nn.Module):
def __init__(self, hidden_size, patch_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
def forward(self, x, vec):
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x
class IntegratedFluxTransformer2DModel(nn.Module):
def __init__(self, **kwargs):
super().__init__()
params = FluxParams(**kwargs)
self.params = params
self.in_channels = params.in_channels * 4
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
for _ in range(params.depth_single_blocks)
]
)
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
def inner_forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None):
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
img = img[:, txt.shape[1]:, ...]
img = self.final_layer(img, vec)
return img
def forward(self, x, timestep, context, y, guidance, **kwargs):
bs, c, h, w = x.shape
input_device = x.device
input_dtype = x.dtype
patch_size = 2
pad_h = (patch_size - x.shape[-2] % patch_size) % patch_size
pad_w = (patch_size - x.shape[-1] % patch_size) % patch_size
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="circular")
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=input_device, dtype=input_dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=input_device, dtype=input_dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=input_device, dtype=input_dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=input_device, dtype=input_dtype)
out = self.inner_forward(img, img_ids, context, txt_ids, timestep, y, guidance)
out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:, :, :h, :w]
return out
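# Hedged sketch (toy latent, not part of this commit) of the packing done in
# forward() above: the 16-channel latent is folded into 2x2 patches (64 features
# per token) before the transformer and unfolded afterwards, which is why
# __init__ multiplies in_channels by 4.
_x = torch.randn(1, 16, 64, 64)
_img = rearrange(_x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
print(_img.shape)                               # torch.Size([1, 1024, 64]): 32*32 tokens
_back = rearrange(_img, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=32, w=32, ph=2, pw=2)
print(torch.equal(_back, _x))                   # True: the packing is lossless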

backend/nn/t5.py (new file, 212 lines)
View File

@@ -0,0 +1,212 @@
import torch
import math
from backend.attention import attention_function
activations = {
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
"relu": torch.nn.functional.relu,
}
class T5LayerNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = torch.nn.Parameter(torch.empty(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
return self.weight.to(x) * x
class T5DenseActDense(torch.nn.Module):
def __init__(self, model_dim, ff_dim, ff_activation):
super().__init__()
self.wi = torch.nn.Linear(model_dim, ff_dim, bias=False)
self.wo = torch.nn.Linear(ff_dim, model_dim, bias=False)
self.act = activations[ff_activation]
def forward(self, x):
x = self.act(self.wi(x))
x = self.wo(x)
return x
class T5DenseGatedActDense(torch.nn.Module):
def __init__(self, model_dim, ff_dim, ff_activation):
super().__init__()
self.wi_0 = torch.nn.Linear(model_dim, ff_dim, bias=False)
self.wi_1 = torch.nn.Linear(model_dim, ff_dim, bias=False)
self.wo = torch.nn.Linear(ff_dim, model_dim, bias=False)
self.act = activations[ff_activation]
def forward(self, x):
hidden_gelu = self.act(self.wi_0(x))
hidden_linear = self.wi_1(x)
x = hidden_gelu * hidden_linear
x = self.wo(x)
return x
class T5LayerFF(torch.nn.Module):
def __init__(self, model_dim, ff_dim, ff_activation, gated_act):
super().__init__()
if gated_act:
self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, ff_activation)
else:
self.DenseReluDense = T5DenseActDense(model_dim, ff_dim, ff_activation)
self.layer_norm = T5LayerNorm(model_dim)
def forward(self, x):
forwarded_states = self.layer_norm(x)
forwarded_states = self.DenseReluDense(forwarded_states)
x += forwarded_states
return x
class T5Attention(torch.nn.Module):
def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias):
super().__init__()
self.q = torch.nn.Linear(model_dim, inner_dim, bias=False)
self.k = torch.nn.Linear(model_dim, inner_dim, bias=False)
self.v = torch.nn.Linear(model_dim, inner_dim, bias=False)
self.o = torch.nn.Linear(inner_dim, model_dim, bias=False)
self.num_heads = num_heads
self.relative_attention_bias = None
if relative_attention_bias:
self.relative_attention_num_buckets = 32
self.relative_attention_max_distance = 128
self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads)
@staticmethod
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
relative_buckets = 0
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
relative_position = torch.abs(relative_position)
else:
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
max_exact = num_buckets // 2
is_small = relative_position < max_exact
relative_position_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_position_if_large = torch.min(
relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
)
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets
def compute_bias(self, query_length, key_length, device, dtype):
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
relative_position = memory_position - context_position
relative_position_bucket = self._relative_position_bucket(
relative_position,
bidirectional=True,
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
values = self.relative_attention_bias(relative_position_bucket).to(dtype)
values = values.permute([2, 0, 1]).unsqueeze(0)
return values
def forward(self, x, mask=None, past_bias=None):
q = self.q(x)
k = self.k(x)
v = self.v(x)
if self.relative_attention_bias is not None:
past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device, x.dtype)
if past_bias is not None:
if mask is not None:
mask = mask + past_bias
else:
mask = past_bias
out = attention_function(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask)
return self.o(out), past_bias
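# Hedged sketch (standalone call) of _relative_position_bucket() above: small
# offsets each get their own bucket, large offsets share logarithmically spaced
# buckets, and with bidirectional=True positive offsets are shifted into the upper
# half of the bucket range.
_rel = torch.tensor([-130, -8, -1, 0, 1, 8, 130])
_buckets = T5Attention._relative_position_bucket(
    _rel, bidirectional=True, num_buckets=32, max_distance=128)
print(_buckets.tolist())    # expected: [15, 8, 1, 0, 17, 24, 31]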
class T5LayerSelfAttention(torch.nn.Module):
def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias):
super().__init__()
self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias)
self.layer_norm = T5LayerNorm(model_dim)
def forward(self, x, mask=None, past_bias=None):
output, past_bias = self.SelfAttention(self.layer_norm(x), mask=mask, past_bias=past_bias)
x += output
return x, past_bias
class T5Block(torch.nn.Module):
def __init__(self, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias):
super().__init__()
self.layer = torch.nn.ModuleList()
self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias))
self.layer.append(T5LayerFF(model_dim, ff_dim, ff_activation, gated_act))
def forward(self, x, mask=None, past_bias=None):
x, past_bias = self.layer[0](x, mask, past_bias)
x = self.layer[-1](x)
return x, past_bias
class T5Stack(torch.nn.Module):
def __init__(self, num_layers, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention):
super().__init__()
self.block = torch.nn.ModuleList(
[T5Block(model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias=((not relative_attention) or (i == 0))) for i in range(num_layers)]
)
self.final_layer_norm = T5LayerNorm(model_dim)
def forward(self, x, attention_mask=None):
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
past_bias = None
for i, l in enumerate(self.block):
x, past_bias = l(x, mask, past_bias)
x = self.final_layer_norm(x)
return x
class T5(torch.nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.num_layers = config["num_layers"]
model_dim = config["d_model"]
self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config["d_ff"], config["dense_act_fn"], config["is_gated_act"], config["num_heads"], config["model_type"] != "umt5")
self.shared = torch.nn.Embedding(config["vocab_size"], model_dim)
def forward(self, input_ids, *args, **kwargs):
x = self.shared(input_ids)
x = torch.nan_to_num(x)
return self.encoder(x, *args, **kwargs)
class IntegratedT5(torch.nn.Module):
def __init__(self, config):
super().__init__()
self.transformer = T5(config)
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
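# Hedged usage sketch (tiny toy config, nowhere near T5-XXL's real sizes; assumes
# backend.attention is importable): the encoder maps token ids through the shared
# embedding and the block stack to per-token hidden states of width d_model.
_toy_config = dict(num_layers=2, d_model=32, d_ff=64, dense_act_fn="gelu_pytorch_tanh",
                   is_gated_act=True, num_heads=4, model_type="t5", vocab_size=100)
_toy_t5 = IntegratedT5(_toy_config)
_ids = torch.randint(0, 100, (1, 7))
print(_toy_t5.transformer(_ids).shape)    # torch.Size([1, 7, 32])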

View File

@@ -22,14 +22,16 @@ def weights_manual_cast(layer, x, skip_dtype=False):
if stream.using_stream:
with stream.stream_context()(stream.mover_stream):
if layer.weight is not None:
weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
if layer.bias is not None:
bias = layer.bias.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
signal = stream.mover_stream.record_event()
else:
if layer.weight is not None:
weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
if layer.bias is not None:
bias = layer.bias.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
return weight, bias, signal
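# Hedged, generic sketch of the pattern above (plain torch.cuda, not this repo's
# stream helper): parameters are copied host-to-device on a separate mover stream
# with non_blocking=True, and an event is recorded so the compute stream waits only
# for that copy instead of synchronizing the whole device.
import torch

if torch.cuda.is_available():
    mover = torch.cuda.Stream()
    weight_cpu = torch.randn(1024, 1024, pin_memory=True)   # pinned memory => truly async copy
    with torch.cuda.stream(mover):
        weight_gpu = weight_cpu.to("cuda", non_blocking=True)
        signal = mover.record_event()
    torch.cuda.current_stream().wait_event(signal)           # gate compute on the copy
    print(weight_gpu.sum())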

View File

@@ -110,6 +110,9 @@ def compile_conditions(cond):
)
)
if 'guidance' in cond:
result['model_conds']['guidance'] = Condition(cond['guidance'])
return [result, ]

View File

@@ -890,7 +890,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
latent_channels = getattr(shared.sd_model, 'latent_channels', opt_C)
latent_channels = shared.sd_model.forge_objects.vae.latent_channels
p.rng = rng.ImageRNG((latent_channels, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
if p.scripts is not None:

View File

@@ -318,17 +318,19 @@ def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_s
def stack_conds(tensors):
# if prompts have wildly different lengths above the limit we'll get tensors of different shapes
# and won't be able to torch.stack them. So this fixes that.
token_count = max([x.shape[0] for x in tensors])
for i in range(len(tensors)):
if tensors[i].shape[0] != token_count:
last_vector = tensors[i][-1:]
last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
return torch.stack(tensors)
try:
result = torch.stack(tensors)
except:
# if prompts have wildly different lengths above the limit we'll get tensors of different shapes
# and won't be able to torch.stack them. So this fixes that.
token_count = max([x.shape[0] for x in tensors])
for i in range(len(tensors)):
if tensors[i].shape[0] != token_count:
last_vector = tensors[i][-1:]
last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
result = torch.stack(tensors)
return result
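# Hedged sketch (toy tensors) of the fallback above: when prompts in one batch
# tokenize to different lengths, the shorter conditioning tensors are padded by
# repeating their last vector until torch.stack() sees a uniform shape.
import torch

a = torch.zeros(77, 4)
b = torch.ones(154, 4)
a_padded = torch.vstack([a, a[-1:].repeat(154 - 77, 1)])
print(a_padded.shape)                         # torch.Size([154, 4])
print(torch.stack([a_padded, b]).shape)       # torch.Size([2, 154, 4])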
def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):

View File

@@ -58,7 +58,7 @@ def model():
model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name)
loaded_model = VAEApprox(latent_channels=shared.sd_model.latent_channels)
loaded_model = VAEApprox(latent_channels=shared.sd_model.forge_objects.vae.latent_channels)
loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
loaded_model.eval()
loaded_model.to(devices.device, devices.dtype)