Added initial support for f-lite model

Jaret Burkett
2025-05-01 11:15:18 -06:00
parent 5890e67a46
commit d9700bdb99
9 changed files with 1076 additions and 4 deletions

View File

@@ -1,7 +1,8 @@
from .chroma import ChromaModel
from .hidream import HidreamModel
+from .f_light import FLiteModel
AI_TOOLKIT_MODELS = [
# put a list of models here
-ChromaModel, HidreamModel
+ChromaModel, HidreamModel, FLiteModel
]

View File

@@ -0,0 +1 @@
from .f_light import FLiteModel

View File

@@ -0,0 +1,295 @@
import os
from typing import TYPE_CHECKING
import torch
import yaml
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from PIL import Image
from toolkit.models.base_model import BaseModel
from toolkit.basic import flush
from diffusers import AutoencoderKL
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.accelerator import unwrap_model
from optimum.quanto import freeze, QTensor
from toolkit.util.quantize import quantize, get_qtype
from transformers import T5TokenizerFast, T5EncoderModel
from .src import FLitePipeline, DiT
if TYPE_CHECKING:
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
scheduler_config = {
"base_image_seq_len": 256,
"base_shift": 0.5,
"max_image_seq_len": 4096,
"max_shift": 1.15,
"num_train_timesteps": 1000,
"shift": 3.0,
"use_dynamic_shifting": True
}
class FLiteModel(BaseModel):
arch = "f-lite"
def __init__(
self,
device,
model_config: ModelConfig,
dtype='bf16',
custom_pipeline=None,
noise_scheduler=None,
**kwargs
):
super().__init__(
device,
model_config,
dtype,
custom_pipeline,
noise_scheduler,
**kwargs
)
self.is_flow_matching = True
self.is_transformer = True
self.target_lora_modules = ['DiT']
# static method to get the noise scheduler
@staticmethod
def get_train_scheduler():
return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
def get_bucket_divisibility(self):
# return the bucket divisibility for the model
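# 16 is presumably the VAE downsampling factor (8) times the DiT patch size (2, the default in the model file below)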
return 16
def load_model(self):
dtype = self.torch_dtype
# will be updated if we detect an existing checkpoint in the training folder
model_path = self.model_config.name_or_path
extras_path = self.model_config.extras_name_or_path
self.print_and_status_update("Loading transformer")
transformer = DiT.from_pretrained(
model_path,
subfolder="dit_model",
torch_dtype=dtype,
)
transformer.to(self.quantize_device, dtype=dtype)
if self.model_config.quantize:
# patch the state dict method
patch_dequantization_on_save(transformer)
quantization_type = get_qtype(self.model_config.qtype)
self.print_and_status_update("Quantizing transformer")
quantize(transformer, weights=quantization_type,
**self.model_config.quantize_kwargs)
freeze(transformer)
transformer.to(self.device_torch)
else:
transformer.to(self.device_torch, dtype=dtype)
flush()
self.print_and_status_update("Loading T5")
tokenizer = T5TokenizerFast.from_pretrained(
extras_path, subfolder="tokenizer", torch_dtype=dtype
)
text_encoder = T5EncoderModel.from_pretrained(
extras_path, subfolder="text_encoder", torch_dtype=dtype
)
text_encoder.to(self.device_torch, dtype=dtype)
flush()
if self.model_config.quantize_te:
self.print_and_status_update("Quantizing T5")
quantize(text_encoder, weights=get_qtype(
self.model_config.qtype))
freeze(text_encoder)
flush()
self.noise_scheduler = FLiteModel.get_train_scheduler()
self.print_and_status_update("Loading VAE")
vae = AutoencoderKL.from_pretrained(
extras_path,
subfolder="vae",
torch_dtype=dtype
)
vae = vae.to(self.device_torch, dtype=dtype)
self.print_and_status_update("Making pipe")
pipe: FLitePipeline = FLitePipeline(
text_encoder=None,
tokenizer=tokenizer,
vae=vae,
dit_model=None,
)
# for quantization, it works best to do these after making the pipe
pipe.text_encoder = text_encoder
pipe.dit_model = transformer
pipe.transformer = transformer
pipe.scheduler = self.noise_scheduler
self.print_and_status_update("Preparing Model")
text_encoder = [pipe.text_encoder]
tokenizer = [pipe.tokenizer]
pipe.transformer = pipe.transformer.to(self.device_torch)
flush()
# just to make sure everything is on the right device and dtype
text_encoder[0].to(self.device_torch)
text_encoder[0].requires_grad_(False)
text_encoder[0].eval()
pipe.transformer = pipe.transformer.to(self.device_torch)
flush()
# save it to the model class
self.vae = vae
self.text_encoder = text_encoder # list of text encoders
self.tokenizer = tokenizer # list of tokenizers
self.model = pipe.transformer
self.pipeline = pipe
self.print_and_status_update("Model Loaded")
def get_generation_pipeline(self):
scheduler = FLiteModel.get_train_scheduler()
# the pipeline has a built-in scheduler; basically Euler flow matching
pipeline = FLitePipeline(
text_encoder=unwrap_model(self.text_encoder[0]),
tokenizer=self.tokenizer[0],
vae=unwrap_model(self.vae),
dit_model=unwrap_model(self.transformer)
)
pipeline.transformer = pipeline.dit_model
pipeline.scheduler = scheduler
return pipeline
def generate_single_image(
self,
pipeline: FLitePipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
generator: torch.Generator,
extra: dict,
):
extra['negative_prompt_embeds'] = unconditional_embeds.text_embeds
img = pipeline(
prompt_embeds=conditional_embeds.text_embeds,
negative_prompt_embeds=unconditional_embeds.text_embeds,
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
latents=gen_config.latents,
generator=generator,
).images[0]
return img
def get_noise_prediction(
self,
latent_model_input: torch.Tensor,
timestep: torch.Tensor, # 0 to 1000 scale
text_embeddings: PromptEmbeds,
**kwargs
):
cast_dtype = self.unet.dtype
noise_pred = self.unet(
latent_model_input.to(
self.device_torch, cast_dtype
),
text_embeddings.text_embeds.to(
self.device_torch, cast_dtype
),
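# the trainer passes timesteps on a 0-1000 scale; the DiT forward expects values in [0, 1]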
timestep / 1000,
)
if isinstance(noise_pred, QTensor):
noise_pred = noise_pred.dequantize()
return noise_pred
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
if isinstance(prompt, str):
prompts = [prompt]
else:
prompts = prompt
if self.pipeline.text_encoder.device != self.device_torch:
self.pipeline.text_encoder.to(self.device_torch)
prompt_embeds, negative_embeds = self.pipeline.encode_prompt(
prompt=prompts,
negative_prompt=None,
device=self.text_encoder[0].device,
dtype=self.torch_dtype,
)
pe = PromptEmbeds(prompt_embeds)
return pe
def get_model_has_grad(self):
# return from a weight if it has grad
return False
def get_te_has_grad(self):
# return from a weight if it has grad
return False
def save_model(self, output_path, meta, save_dtype):
# only save the transformer (diffusers format)
transformer: DiT = unwrap_model(self.transformer)
transformer.save_pretrained(
save_directory=os.path.join(output_path, 'dit_model'),
safe_serialization=True,
)
# save out meta config
meta_path = os.path.join(output_path, 'aitk_meta.yaml')
with open(meta_path, 'w') as f:
yaml.dump(meta, f)
def get_loss_target(self, *args, **kwargs):
noise = kwargs.get('noise')
batch = kwargs.get('batch')
# return (noise - batch.latents).detach()
return (batch.latents - noise).detach()
def convert_lora_weights_before_save(self, state_dict):
# keys currently start with "transformer." but need to start with "diffusion_model." for ComfyUI
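# e.g. a hypothetical key "transformer.blocks.0.self_attn.qkv.lora_A.weight"
# becomes "diffusion_model.blocks.0.self_attn.qkv.lora_A.weight"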
new_sd = {}
for key, value in state_dict.items():
new_key = key.replace("transformer.", "diffusion_model.")
new_sd[new_key] = value
return new_sd
def convert_lora_weights_before_load(self, state_dict):
# keys are saved with "diffusion_model." but need to start with "transformer." for ai-toolkit
new_sd = {}
for key, value in state_dict.items():
new_key = key.replace("diffusion_model.", "transformer.")
new_sd[new_key] = value
return new_sd
def get_base_model_version(self):
return "f-lite"
def get_stepped_pred(self, pred, noise):
# just used for DFE support
latents = pred + noise
return latents
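# A minimal sketch (not part of the commit) of the flow-matching convention that
# get_loss_target and get_stepped_pred above assume: the model predicts a velocity from
# the noise toward the clean latents, so a perfect prediction plus the noise recovers the latents.
def _flow_matching_sanity_check(latents: torch.Tensor, noise: torch.Tensor) -> None:
    target = latents - noise       # what get_loss_target returns
    recovered = target + noise     # what get_stepped_pred computes from a perfect prediction
    assert torch.allclose(recovered, latents)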

View File

@@ -0,0 +1,5 @@
from .pipeline import FLitePipeline, FLitePipelineOutput, APGConfig
from .model import DiT
__all__ = ["FLitePipeline", "FLitePipelineOutput", "APGConfig", "DiT"]

View File

@@ -0,0 +1,456 @@
# originally from https://github.com/fal-ai/f-lite/blob/main/f_lite/model.py but modified slightly
import math
import torch
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils.accelerate_utils import apply_forward_hook
from einops import rearrange
from peft import get_peft_model_state_dict, set_peft_model_state_dict
from torch import nn
def timestep_embedding(t, dim, max_period=10000):
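# sinusoidal embedding: half the channels get cos, half get sin, over log-spaced frequencies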
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
device=t.device
)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
return embedding
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-6, trainable=False):
super().__init__()
self.eps = eps
if trainable:
self.weight = nn.Parameter(torch.ones(dim))
else:
self.weight = None
def forward(self, x):
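# RMS normalization: x / sqrt(mean(x^2) + eps), optionally scaled by a learned weight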
x_dtype = x.dtype
x = x.float()
norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
if self.weight is not None:
return (x * norm * self.weight).to(dtype=x_dtype)
else:
return (x * norm).to(dtype=x_dtype)
class QKNorm(nn.Module):
"""Normalizing the query and the key independently, as Flux proposes"""
def __init__(self, dim, trainable=False):
super().__init__()
self.query_norm = RMSNorm(dim, trainable=trainable)
self.key_norm = RMSNorm(dim, trainable=trainable)
def forward(self, q, k):
q = self.query_norm(q)
k = self.key_norm(k)
return q, k
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
is_self_attn=True,
cross_attn_input_size=None,
residual_v=False,
dynamic_softmax_temperature=False,
):
super().__init__()
assert dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.is_self_attn = is_self_attn
self.residual_v = residual_v
self.dynamic_softmax_temperature = dynamic_softmax_temperature
if is_self_attn:
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
else:
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.context_kv = nn.Linear(cross_attn_input_size, dim * 2, bias=qkv_bias)
self.proj = nn.Linear(dim, dim, bias=False)
if residual_v:
self.lambda_param = nn.Parameter(torch.tensor(0.5).reshape(1))
self.qk_norm = QKNorm(self.head_dim)
def forward(self, x, context=None, v_0=None, rope=None):
if self.is_self_attn:
qkv = self.qkv(x)
qkv = rearrange(qkv, "b l (k h d) -> k b h l d", k=3, h=self.num_heads)
q, k, v = qkv.unbind(0)
if self.residual_v and v_0 is not None:
v = self.lambda_param * v + (1 - self.lambda_param) * v_0
if rope is not None:
# print(q.shape, rope[0].shape, rope[1].shape)
q = apply_rotary_emb(q, rope[0], rope[1])
k = apply_rotary_emb(k, rope[0], rope[1])
# https://arxiv.org/abs/2306.08645
# https://arxiv.org/abs/2410.01104
# rationale: as the number of tokens grows, the categorical attention distribution becomes more uniform,
# so the softmax temperature is adjusted to compensate.
token_length = q.shape[2]
if self.dynamic_softmax_temperature:
ratio = math.sqrt(math.log(token_length) / math.log(1040.0)) # 1024 + 16
k = k * ratio
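# for token_length > 1040 the ratio is > 1, so the q.k logits grow and the softmax sharpens,
# counteracting the flattening that comes with more tokens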
q, k = self.qk_norm(q, k)
else:
q = rearrange(self.q(x), "b l (h d) -> b h l d", h=self.num_heads)
kv = rearrange(
self.context_kv(context),
"b l (k h d) -> k b h l d",
k=2,
h=self.num_heads,
)
k, v = kv.unbind(0)
q, k = self.qk_norm(q, k)
x = F.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "b h l d -> b l (h d)")
x = self.proj(x)
return x, v if self.is_self_attn else None
class DiTBlock(nn.Module):
def __init__(
self,
hidden_size,
cross_attn_input_size,
num_heads,
mlp_ratio=4.0,
qkv_bias=True,
residual_v=False,
dynamic_softmax_temperature=False,
):
super().__init__()
self.hidden_size = hidden_size
self.norm1 = RMSNorm(hidden_size, trainable=qkv_bias)
self.self_attn = Attention(
hidden_size,
num_heads=num_heads,
qkv_bias=qkv_bias,
is_self_attn=True,
residual_v=residual_v,
dynamic_softmax_temperature=dynamic_softmax_temperature,
)
if cross_attn_input_size is not None:
self.norm2 = RMSNorm(hidden_size, trainable=qkv_bias)
self.cross_attn = Attention(
hidden_size,
num_heads=num_heads,
qkv_bias=qkv_bias,
is_self_attn=False,
cross_attn_input_size=cross_attn_input_size,
dynamic_softmax_temperature=dynamic_softmax_temperature,
)
else:
self.norm2 = None
self.cross_attn = None
self.norm3 = RMSNorm(hidden_size, trainable=qkv_bias)
mlp_hidden = int(hidden_size * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden),
nn.GELU(),
nn.Linear(mlp_hidden, hidden_size),
)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 9 * hidden_size, bias=True))
self.adaLN_modulation[-1].weight.data.zero_()
self.adaLN_modulation[-1].bias.data.zero_()
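# with the modulation projection zero-initialized, every shift/scale/gate starts at zero,
# so each block acts as an identity mapping at initialization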
# @torch.compile(mode='reduce-overhead')
def forward(self, x, context, c, v_0=None, rope=None):
(
shift_sa,
scale_sa,
gate_sa,
shift_ca,
scale_ca,
gate_ca,
shift_mlp,
scale_mlp,
gate_mlp,
) = self.adaLN_modulation(c).chunk(9, dim=1)
scale_sa = scale_sa[:, None, :]
scale_ca = scale_ca[:, None, :]
scale_mlp = scale_mlp[:, None, :]
shift_sa = shift_sa[:, None, :]
shift_ca = shift_ca[:, None, :]
shift_mlp = shift_mlp[:, None, :]
gate_sa = gate_sa[:, None, :]
gate_ca = gate_ca[:, None, :]
gate_mlp = gate_mlp[:, None, :]
norm_x = self.norm1(x.clone())
norm_x = norm_x * (1 + scale_sa) + shift_sa
attn_out, v = self.self_attn(norm_x, v_0=v_0, rope=rope)
x = x + attn_out * gate_sa
if self.norm2 is not None:
norm_x = self.norm2(x)
norm_x = norm_x * (1 + scale_ca) + shift_ca
x = x + self.cross_attn(norm_x, context)[0] * gate_ca
norm_x = self.norm3(x)
norm_x = norm_x * (1 + scale_mlp) + shift_mlp
x = x + self.mlp(norm_x) * gate_mlp
return x, v
class PatchEmbed(nn.Module):
def __init__(self, patch_size=16, in_channels=3, embed_dim=768):
super().__init__()
self.patch_proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
self.patch_size = patch_size
def forward(self, x):
B, C, H, W = x.shape
x = self.patch_proj(x)
x = rearrange(x, "b c h w -> b (h w) c")
return x
class TwoDimRotary(torch.nn.Module):
def __init__(self, dim, base=10000, h=256, w=256):
super().__init__()
self.inv_freq = torch.FloatTensor([1.0 / (base ** (i / dim)) for i in range(0, dim, 2)])
self.h = h
self.w = w
t_h = torch.arange(h, dtype=torch.float32)
t_w = torch.arange(w, dtype=torch.float32)
freqs_h = torch.outer(t_h, self.inv_freq).unsqueeze(1) # h, 1, d / 2
freqs_w = torch.outer(t_w, self.inv_freq).unsqueeze(0) # 1, w, d / 2
freqs_h = freqs_h.repeat(1, w, 1) # h, w, d / 2
freqs_w = freqs_w.repeat(h, 1, 1) # h, w, d / 2
freqs_hw = torch.cat([freqs_h, freqs_w], 2) # h, w, d
self.register_buffer("freqs_hw_cos", freqs_hw.cos())
self.register_buffer("freqs_hw_sin", freqs_hw.sin())
def forward(self, x, height_width=None, extend_with_register_tokens=0):
if height_width is not None:
this_h, this_w = height_width
else:
this_hw = x.shape[1]
this_h, this_w = int(this_hw**0.5), int(this_hw**0.5)
cos = self.freqs_hw_cos[0 : this_h, 0 : this_w]
sin = self.freqs_hw_sin[0 : this_h, 0 : this_w]
cos = cos.clone().reshape(this_h * this_w, -1)
sin = sin.clone().reshape(this_h * this_w, -1)
# prepend entries for N register tokens that receive no rotation (cos=1, sin=0)
if extend_with_register_tokens > 0:
cos = torch.cat(
[
torch.ones(extend_with_register_tokens, cos.shape[1]).to(cos.device),
cos,
],
0,
)
sin = torch.cat(
[
torch.zeros(extend_with_register_tokens, sin.shape[1]).to(sin.device),
sin,
],
0,
)
return cos[None, None, :, :], sin[None, None, :, :] # [1, 1, T + N, Attn-dim]
def apply_rotary_emb(x, cos, sin):
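# rotate pairs formed from the two halves of the head dimension by the per-position angles in cos/sin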
orig_dtype = x.dtype
x = x.to(dtype=torch.float32)
assert x.ndim == 4 # multihead attention
d = x.shape[3] // 2
x1 = x[..., :d]
x2 = x[..., d:]
y1 = x1 * cos + x2 * sin
y2 = x1 * (-sin) + x2 * cos
return torch.cat([y1, y2], 3).to(dtype=orig_dtype)
class DiT(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): # type: ignore[misc]
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels=4,
patch_size=2,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
cross_attn_input_size=128,
residual_v=False,
train_bias_and_rms=True,
use_rope=True,
gradient_checkpoint=False,
dynamic_softmax_temperature=False,
rope_base=10000,
):
super().__init__()
self.patch_embed = PatchEmbed(patch_size, in_channels, hidden_size)
if use_rope:
self.rope = TwoDimRotary(hidden_size // (2 * num_heads), base=rope_base, h=512, w=512)
else:
self.positional_embedding = nn.Parameter(torch.zeros(1, 2048, hidden_size))
self.register_tokens = nn.Parameter(torch.randn(1, 16, hidden_size))
self.time_embed = nn.Sequential(
nn.Linear(hidden_size, 4 * hidden_size),
nn.SiLU(),
nn.Linear(4 * hidden_size, hidden_size),
)
self.blocks = nn.ModuleList(
[
DiTBlock(
hidden_size=hidden_size,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
cross_attn_input_size=cross_attn_input_size,
residual_v=residual_v,
qkv_bias=train_bias_and_rms,
dynamic_softmax_temperature=dynamic_softmax_temperature,
)
for _ in range(depth)
]
)
self.final_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
self.final_norm = RMSNorm(hidden_size, trainable=train_bias_and_rms)
self.final_proj = nn.Linear(hidden_size, patch_size * patch_size * in_channels)
nn.init.zeros_(self.final_modulation[-1].weight)
nn.init.zeros_(self.final_modulation[-1].bias)
nn.init.zeros_(self.final_proj.weight)
nn.init.zeros_(self.final_proj.bias)
self.paramstatus = {}
for n, p in self.named_parameters():
self.paramstatus[n] = {
"shape": p.shape,
"requires_grad": p.requires_grad,
}
self.gradient_checkpointing = False
def save_lora_weights(self, save_directory):
"""Save LoRA weights to a file"""
lora_state_dict = get_peft_model_state_dict(self)
torch.save(lora_state_dict, f"{save_directory}/lora_weights.pt")
def load_lora_weights(self, load_directory):
"""Load LoRA weights from a file"""
lora_state_dict = torch.load(f"{load_directory}/lora_weights.pt")
set_peft_model_state_dict(self, lora_state_dict)
@apply_forward_hook
def forward(self, x, context, timesteps):
b, c, h, w = x.shape
x = self.patch_embed(x) # b, T, d
x = torch.cat([self.register_tokens.repeat(b, 1, 1), x], 1) # b, T + N, d
if self.config.use_rope:
cos, sin = self.rope(
x,
extend_with_register_tokens=16,
height_width=(h // self.config.patch_size, w // self.config.patch_size),
)
else:
x = x + self.positional_embedding.repeat(b, 1, 1)[:, : x.shape[1], :]
cos, sin = None, None
t_emb = timestep_embedding(timesteps * 1000, self.config.hidden_size).to(x.device, dtype=x.dtype)
t_emb = self.time_embed(t_emb)
v_0 = None
for _idx, block in enumerate(self.blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
x, v = self._gradient_checkpointing_func(
block,
x,
context,
t_emb,
v_0,
(cos, sin)
)
else:
x, v = block(x, context, t_emb, v_0, (cos, sin))
if v_0 is None:
v_0 = v
x = x[:, 16:, :]
final_shift, final_scale = self.final_modulation(t_emb).chunk(2, dim=1)
x = self.final_norm(x)
x = x * (1 + final_scale[:, None, :]) + final_shift[:, None, :]
x = self.final_proj(x)
x = rearrange(
x,
"b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
h=h // self.config.patch_size,
w=w // self.config.patch_size,
p1=self.config.patch_size,
p2=self.config.patch_size,
)
return x
if __name__ == "__main__":
model = DiT(
in_channels=4,
patch_size=2,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
cross_attn_input_size=128,
residual_v=False,
train_bias_and_rms=True,
use_rope=True,
).cuda()
print(
model(
torch.randn(1, 4, 64, 64).cuda(),
torch.randn(1, 37, 128).cuda(),
torch.tensor([1.0]).cuda(),
)
)

View File

@@ -0,0 +1,308 @@
# originally from https://github.com/fal-ai/f-lite/blob/main/f_lite/pipeline.py but modified slightly
import logging
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers import AutoencoderKL, DiffusionPipeline
from diffusers.utils import BaseOutput
from diffusers.utils.torch_utils import randn_tensor
from PIL import Image
from torch import FloatTensor
from tqdm.auto import tqdm
from transformers import T5EncoderModel, T5TokenizerFast
logger = logging.getLogger(__name__)
@dataclass
class APGConfig:
"""APG (Augmented Parallel Guidance) configuration"""
enabled: bool = True
orthogonal_threshold: float = 0.03
@dataclass
class FLitePipelineOutput(BaseOutput):
"""
Output class for FLitePipeline pipeline.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""
images: Union[List[Image.Image], np.ndarray]
class FLitePipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using F-Lite model.
This model inherits from [`DiffusionPipeline`].
"""
model_cpu_offload_seq = "text_encoder->dit_model->vae"
dit_model: torch.nn.Module
vae: AutoencoderKL
text_encoder: T5EncoderModel
tokenizer: T5TokenizerFast
_progress_bar_config: Dict[str, Any]
def __init__(
self, dit_model: torch.nn.Module, vae: AutoencoderKL, text_encoder: T5EncoderModel, tokenizer: T5TokenizerFast
):
super().__init__()
# Register all modules for the pipeline
# Access DiffusionPipeline's register_modules directly to avoid mypy error
DiffusionPipeline.register_modules(
self, dit_model=dit_model, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer
)
# Move models to channels last for better performance
# AutoencoderKL inherits from torch.nn.Module which has these methods
if hasattr(self.vae, "to"):
self.vae.to(memory_format=torch.channels_last)
if hasattr(self.vae, "requires_grad_"):
self.vae.requires_grad_(False)
if hasattr(self.text_encoder, "requires_grad_"):
self.text_encoder.requires_grad_(False)
# Constants
self.vae_scale_factor = 8
self.return_index = -8 # T5 hidden state index to use
def enable_vae_slicing(self):
"""Enable VAE slicing for memory efficiency."""
if hasattr(self.vae, "enable_slicing"):
self.vae.enable_slicing()
def enable_vae_tiling(self):
"""Enable VAE tiling for memory efficiency."""
if hasattr(self.vae, "enable_tiling"):
self.vae.enable_tiling()
def set_progress_bar_config(self, **kwargs):
"""Set progress bar configuration."""
self._progress_bar_config = kwargs
def progress_bar(self, iterable=None, **kwargs):
"""Create progress bar for iterations."""
self._progress_bar_config = getattr(self, "_progress_bar_config", None) or {}
config = {**self._progress_bar_config, **kwargs}
return tqdm(iterable, **config)
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 512,
return_index: int = -8,
) -> Tuple[FloatTensor, FloatTensor]:
"""Encodes the prompt and negative prompt."""
if isinstance(prompt, str):
prompt = [prompt]
device = device or self.text_encoder.device
# Text encoder forward pass
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
prompt_embeds = self.text_encoder(text_input_ids, return_dict=True, output_hidden_states=True)
prompt_embeds_tensor = prompt_embeds.hidden_states[return_index]
if return_index != -1:
prompt_embeds_tensor = self.text_encoder.encoder.final_layer_norm(prompt_embeds_tensor)
prompt_embeds_tensor = self.text_encoder.encoder.dropout(prompt_embeds_tensor)
dtype = dtype or next(self.text_encoder.parameters()).dtype
prompt_embeds_tensor = prompt_embeds_tensor.to(dtype=dtype, device=device)
# Handle negative prompts
if negative_prompt is None:
negative_embeds = torch.zeros_like(prompt_embeds_tensor)
else:
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt]
negative_result = self.encode_prompt(
prompt=negative_prompt, device=device, dtype=dtype, return_index=return_index
)
negative_embeds = negative_result[0]
# Explicitly cast both tensors to FloatTensor for mypy
from typing import cast
prompt_tensor = cast(FloatTensor, prompt_embeds_tensor.to(dtype=dtype))
negative_tensor = cast(FloatTensor, negative_embeds.to(dtype=dtype))
return (prompt_tensor, negative_tensor)
def to(self, torch_device=None, torch_dtype=None, silence_dtype_warnings=False):
"""Move pipeline components to specified device and dtype."""
if hasattr(self, "vae"):
self.vae.to(device=torch_device, dtype=torch_dtype)
if hasattr(self, "text_encoder"):
self.text_encoder.to(device=torch_device, dtype=torch_dtype)
if hasattr(self, "dit_model"):
self.dit_model.to(device=torch_device, dtype=torch_dtype)
return self
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]]=None,
prompt_embeds: Optional[FloatTensor] = None,
height: Optional[int] = 1024,
width: Optional[int] = 1024,
num_inference_steps: int = 30,
guidance_scale: float = 6.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_embeds: Optional[FloatTensor] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
dtype: Optional[torch.dtype] = None,
alpha: Optional[float] = None,
apg_config: Optional[APGConfig] = None,
**kwargs,
):
"""Generate images from text prompt."""
# Ensure height and width are not None for calculation
if height is None:
height = 1024
if width is None:
width = 1024
dtype = dtype or next(self.dit_model.parameters()).dtype
apg_config = apg_config or APGConfig(enabled=False)
device = self._execution_device
# 2. Encode prompts
prompt_batch_size = len(prompt) if isinstance(prompt, list) else 1
batch_size = prompt_batch_size * num_images_per_prompt
if prompt_embeds is None or negative_prompt_embeds is None:
prompt_embeds, negative_embeds = self.encode_prompt(
prompt=prompt, negative_prompt=negative_prompt, device=self.text_encoder.device, dtype=dtype,
return_index=self.return_index,
)
else:
negative_embeds = negative_prompt_embeds
# Repeat embeddings for num_images_per_prompt
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
negative_embeds = negative_embeds.repeat_interleave(num_images_per_prompt, dim=0)
# 3. Initialize latents
latent_height = height // self.vae_scale_factor
latent_width = width // self.vae_scale_factor
if isinstance(generator, list):
if len(generator) != batch_size:
raise ValueError(f"Got {len(generator)} generators for {batch_size} samples")
latents = randn_tensor((batch_size, 16, latent_height, latent_width), generator=generator, device=device, dtype=dtype)
acc_latents = latents.clone()
# 4. Calculate alpha if not provided
if alpha is None:
image_token_size = latent_height * latent_width
alpha = 2 * math.sqrt(image_token_size / (64 * 64))
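# alpha grows with the latent token count relative to a 64x64 latent grid; the shift below
# pushes timesteps toward the high-noise end for larger images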
# 6. Sampling loop
self.dit_model.eval()
# Check if guidance is needed
do_classifier_free_guidance = guidance_scale >= 1.0
for i in self.progress_bar(range(num_inference_steps, 0, -1)):
# Calculate timesteps
t = i / num_inference_steps
t_next = (i - 1) / num_inference_steps
# Scale timesteps according to alpha
t = t * alpha / (1 + (alpha - 1) * t)
t_next = t_next * alpha / (1 + (alpha - 1) * t_next)
dt = t - t_next
# Create tensor with proper device
t_tensor = torch.tensor([t] * batch_size, device=device, dtype=dtype)
if do_classifier_free_guidance:
# Duplicate latents for both conditional and unconditional inputs
latents_input = torch.cat([latents] * 2)
# Concatenate negative and positive prompt embeddings
context_input = torch.cat([negative_embeds, prompt_embeds])
# Duplicate timesteps for the batch
t_input = torch.cat([t_tensor] * 2)
# Get model predictions in a single pass
model_outputs = self.dit_model(latents_input, context_input, t_input)
# Split outputs back into unconditional and conditional predictions
uncond_output, cond_output = model_outputs.chunk(2)
if apg_config.enabled:
# Augmented Parallel Guidance
dy = cond_output
dd = cond_output - uncond_output
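# dd is split into components parallel and orthogonal to the conditional prediction;
# only the (threshold-clipped) orthogonal part is scaled by (guidance_scale - 1)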
# Find parallel direction
parallel_direction = (dy * dd).sum() / (dy * dy).sum() * dy
orthogonal_direction = dd - parallel_direction
# Scale orthogonal component
orthogonal_std = orthogonal_direction.std()
orthogonal_scale = min(1, apg_config.orthogonal_threshold / orthogonal_std)
orthogonal_direction = orthogonal_direction * orthogonal_scale
model_output = dy + (guidance_scale - 1) * orthogonal_direction
else:
# Standard classifier-free guidance
model_output = uncond_output + guidance_scale * (cond_output - uncond_output)
else:
# If no guidance needed, just run the model normally
model_output = self.dit_model(latents, prompt_embeds, t_tensor)
# Update latents
acc_latents = acc_latents + dt * model_output.to(device)
latents = acc_latents.clone()
# 7. Decode latents
# These checks handle the case where mypy doesn't recognize these attributes
scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215) if hasattr(self.vae, "config") else 0.18215
shift_factor = getattr(self.vae.config, "shift_factor", 0) if hasattr(self.vae, "config") else 0
latents = latents / scaling_factor + shift_factor
vae_dtype = self.vae.dtype if hasattr(self.vae, "dtype") else dtype
decoded_images = self.vae.decode(latents.to(vae_dtype)).sample if hasattr(self.vae, "decode") else latents
# Offload all models
try:
self.maybe_free_model_hooks()
except AttributeError as e:
if "OptimizedModule" in str(e):
import warnings
warnings.warn(
"Encountered 'OptimizedModule' error when offloading models. "
"This issue might be fixed in the future by: "
"https://github.com/huggingface/diffusers/pull/10730"
)
else:
raise
# 8. Post-process images
images = (decoded_images / 2 + 0.5).clamp(0, 1)
# Convert to PIL Images
images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu()
pil_images = [Image.fromarray(img.permute(1, 2, 0).numpy()) for img in images]
return FLitePipelineOutput(
images=pil_images,
)
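# A minimal usage sketch (not part of the commit). The checkpoint id below is hypothetical;
# the subfolder layout it assumes (dit_model / vae / text_encoder / tokenizer) matches what
# FLiteModel.load_model expects:
#
#   pipe = FLitePipeline.from_pretrained("<f-lite-checkpoint>", torch_dtype=torch.bfloat16)
#   pipe.to("cuda")
#   image = pipe("a watercolor fox", num_inference_steps=30, guidance_scale=6.0).images[0]
#   image.save("fox.png")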

View File

@@ -249,7 +249,8 @@ class DiffusionFeatureExtractor3(nn.Module):
# lpips_weight=1.0,
lpips_weight=10.0,
clip_weight=0.1,
-pixel_weight=0.1
+pixel_weight=0.1,
+model=None
):
dtype = torch.bfloat16
device = self.vae.device
@@ -274,7 +275,10 @@ class DiffusionFeatureExtractor3(nn.Module):
# stepped_latents = torch.cat(stepped_chunks, dim=0)
-stepped_latents = noise - noise_pred
+if model is not None and hasattr(model, 'get_stepped_pred'):
+stepped_latents = model.get_stepped_pred(noise_pred, noise)
+else:
+stepped_latents = noise - noise_pred
latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)

View File

@@ -2283,6 +2283,7 @@ class StableDiffusion:
bleed_latents: torch.FloatTensor = None,
is_input_scaled=False,
return_first_prediction=False,
+bypass_guidance_embedding=False,
**kwargs,
):
timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps]
@@ -2299,6 +2300,7 @@ class StableDiffusion:
add_time_ids=add_time_ids,
is_input_scaled=is_input_scaled,
return_conditional_pred=True,
+bypass_guidance_embedding=bypass_guidance_embedding,
**kwargs,
)
# some schedulers need to run separately, so do that. (euler for example)

View File

@@ -145,7 +145,7 @@ if TYPE_CHECKING:
def concat_prompt_embeddings(
unconditional: 'PromptEmbeds',
conditional: 'PromptEmbeds',
-n_imgs: int,
+n_imgs: int=0,
):
from toolkit.stable_diffusion_model import PromptEmbeds
text_embeds = torch.cat(