From d9700bdb992883531330080b237d79fa982f19ee Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Thu, 1 May 2025 11:15:18 -0600
Subject: [PATCH] Added initial support for f-lite model

---
 .../diffusion_models/__init__.py              |   3 +-
 .../diffusion_models/f_light/__init__.py      |   1 +
 .../diffusion_models/f_light/f_light.py       | 295 +++++++++++
 .../diffusion_models/f_light/src/__init__.py  |   5 +
 .../diffusion_models/f_light/src/model.py     | 456 ++++++++++++++++++
 .../diffusion_models/f_light/src/pipeline.py  | 308 ++++++++++++
 .../models/diffusion_feature_extraction.py    |   8 +-
 toolkit/stable_diffusion_model.py             |   2 +
 toolkit/train_tools.py                        |   2 +-
 9 files changed, 1076 insertions(+), 4 deletions(-)
 create mode 100644 extensions_built_in/diffusion_models/f_light/__init__.py
 create mode 100644 extensions_built_in/diffusion_models/f_light/f_light.py
 create mode 100644 extensions_built_in/diffusion_models/f_light/src/__init__.py
 create mode 100644 extensions_built_in/diffusion_models/f_light/src/model.py
 create mode 100644 extensions_built_in/diffusion_models/f_light/src/pipeline.py

diff --git a/extensions_built_in/diffusion_models/__init__.py b/extensions_built_in/diffusion_models/__init__.py
index c775be3b..0cc323b9 100644
--- a/extensions_built_in/diffusion_models/__init__.py
+++ b/extensions_built_in/diffusion_models/__init__.py
@@ -1,7 +1,8 @@
 from .chroma import ChromaModel
 from .hidream import HidreamModel
+from .f_light import FLiteModel
 
 AI_TOOLKIT_MODELS = [
     # put a list of models here
-    ChromaModel, HidreamModel
+    ChromaModel, HidreamModel, FLiteModel
 ]
diff --git a/extensions_built_in/diffusion_models/f_light/__init__.py b/extensions_built_in/diffusion_models/f_light/__init__.py
new file mode 100644
index 00000000..6a438f93
--- /dev/null
+++ b/extensions_built_in/diffusion_models/f_light/__init__.py
@@ -0,0 +1 @@
+from .f_light import FLiteModel
\ No newline at end of file
diff --git a/extensions_built_in/diffusion_models/f_light/f_light.py b/extensions_built_in/diffusion_models/f_light/f_light.py
new file mode 100644
index 00000000..2fc4f5a0
--- /dev/null
+++ b/extensions_built_in/diffusion_models/f_light/f_light.py
@@ -0,0 +1,295 @@
+import os
+from typing import TYPE_CHECKING
+
+import torch
+import yaml
+from toolkit.config_modules import GenerateImageConfig, ModelConfig
+from PIL import Image
+from toolkit.models.base_model import BaseModel
+from toolkit.basic import flush
+from diffusers import AutoencoderKL
+from toolkit.prompt_utils import PromptEmbeds
+from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
+from toolkit.dequantize import patch_dequantization_on_save
+from toolkit.accelerator import unwrap_model
+from optimum.quanto import freeze, QTensor
+from toolkit.util.quantize import quantize, get_qtype
+from transformers import T5TokenizerFast, T5EncoderModel
+from .src import FLitePipeline, DiT
+
+if TYPE_CHECKING:
+    from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
+
+scheduler_config = {
+    "base_image_seq_len": 256,
+    "base_shift": 0.5,
+    "max_image_seq_len": 4096,
+    "max_shift": 1.15,
+    "num_train_timesteps": 1000,
+    "shift": 3.0,
+    "use_dynamic_shifting": True
+}
+
+
+class FLiteModel(BaseModel):
+    arch = "f-lite"
+
+    def __init__(
+        self,
+        device,
+        model_config: ModelConfig,
+        dtype='bf16',
+        custom_pipeline=None,
+        noise_scheduler=None,
+        **kwargs
+    ):
+        super().__init__(
+            device,
+            model_config,
+            dtype,
+            custom_pipeline,
+            noise_scheduler,
+            **kwargs
+        )
+        self.is_flow_matching = True
+        self.is_transformer = True
+        self.target_lora_modules = ['DiT']
+
+    # static method to get the noise scheduler
+    @staticmethod
+    def get_train_scheduler():
+        return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
+
+    def get_bucket_divisibility(self):
+        # return the bucket divisibility for the model
+        return 16
+
+    def load_model(self):
+        dtype = self.torch_dtype
+
+        # will be updated if we detect an existing checkpoint in the training folder
+        model_path = self.model_config.name_or_path
+
+        extras_path = self.model_config.extras_name_or_path
+
+        self.print_and_status_update("Loading transformer")
+
+        transformer = DiT.from_pretrained(
+            model_path,
+            subfolder="dit_model",
+            torch_dtype=dtype,
+        )
+
+        transformer.to(self.quantize_device, dtype=dtype)
+
+        if self.model_config.quantize:
+            # patch the state dict method
+            patch_dequantization_on_save(transformer)
+            quantization_type = get_qtype(self.model_config.qtype)
+            self.print_and_status_update("Quantizing transformer")
+            quantize(transformer, weights=quantization_type,
+                     **self.model_config.quantize_kwargs)
+            freeze(transformer)
+            transformer.to(self.device_torch)
+        else:
+            transformer.to(self.device_torch, dtype=dtype)
+
+        flush()
+
+        self.print_and_status_update("Loading T5")
+        tokenizer = T5TokenizerFast.from_pretrained(
+            extras_path, subfolder="tokenizer", torch_dtype=dtype
+        )
+        text_encoder = T5EncoderModel.from_pretrained(
+            extras_path, subfolder="text_encoder", torch_dtype=dtype
+        )
+        text_encoder.to(self.device_torch, dtype=dtype)
+        flush()
+
+        if self.model_config.quantize_te:
+            self.print_and_status_update("Quantizing T5")
+            quantize(text_encoder, weights=get_qtype(
+                self.model_config.qtype))
+            freeze(text_encoder)
+            flush()
+
+        self.noise_scheduler = FLiteModel.get_train_scheduler()
+
+        self.print_and_status_update("Loading VAE")
+        vae = AutoencoderKL.from_pretrained(
+            extras_path,
+            subfolder="vae",
+            torch_dtype=dtype
+        )
+        vae = vae.to(self.device_torch, dtype=dtype)
+
+        self.print_and_status_update("Making pipe")
+
+        pipe: FLitePipeline = FLitePipeline(
+            text_encoder=None,
+            tokenizer=tokenizer,
+            vae=vae,
+            dit_model=None,
+        )
+        # for quantization, it works best to do these after making the pipe
+        pipe.text_encoder = text_encoder
+        pipe.dit_model = transformer
+        pipe.transformer = transformer
+        pipe.scheduler = self.noise_scheduler
+
+        self.print_and_status_update("Preparing Model")
+
+        text_encoder = [pipe.text_encoder]
+        tokenizer = [pipe.tokenizer]
+
+        pipe.transformer = pipe.transformer.to(self.device_torch)
+
+        flush()
+        # just to make sure everything is on the right device and dtype
+        text_encoder[0].to(self.device_torch)
+        text_encoder[0].requires_grad_(False)
+        text_encoder[0].eval()
+        pipe.transformer = pipe.transformer.to(self.device_torch)
+        flush()
+
+        # save it to the model class
+        self.vae = vae
+        self.text_encoder = text_encoder  # list of text encoders
+        self.tokenizer = tokenizer  # list of tokenizers
+        self.model = pipe.transformer
+        self.pipeline = pipe
+        self.print_and_status_update("Model Loaded")
+
+    def get_generation_pipeline(self):
+        scheduler = FLiteModel.get_train_scheduler()
+        # the pipeline has a built-in scheduler; basically Euler flow matching
+        pipeline = FLitePipeline(
+            text_encoder=unwrap_model(self.text_encoder[0]),
+            tokenizer=self.tokenizer[0],
+            vae=unwrap_model(self.vae),
+            dit_model=unwrap_model(self.transformer)
+        )
+        pipeline.transformer = pipeline.dit_model
+        pipeline.scheduler = scheduler
+
+        return pipeline
+
+    def generate_single_image(
+        self,
+        pipeline: FLitePipeline,
+        gen_config: GenerateImageConfig,
+        conditional_embeds: PromptEmbeds,
+        unconditional_embeds: PromptEmbeds,
+        generator: torch.Generator,
+        extra: dict,
+    ):
+
+        extra['negative_prompt_embeds'] = unconditional_embeds.text_embeds
+
+        img = pipeline(
+            prompt_embeds=conditional_embeds.text_embeds,
+            negative_prompt_embeds=unconditional_embeds.text_embeds,
+            height=gen_config.height,
+            width=gen_config.width,
+            num_inference_steps=gen_config.num_inference_steps,
+            guidance_scale=gen_config.guidance_scale,
+            latents=gen_config.latents,
+            generator=generator,
+        ).images[0]
+        return img
+
+    def get_noise_prediction(
+        self,
+        latent_model_input: torch.Tensor,
+        timestep: torch.Tensor,  # 0 to 1000 scale
+        text_embeddings: PromptEmbeds,
+        **kwargs
+    ):
+        cast_dtype = self.unet.dtype
+
+        noise_pred = self.unet(
+            latent_model_input.to(
+                self.device_torch, cast_dtype
+            ),
+            text_embeddings.text_embeds.to(
+                self.device_torch, cast_dtype
+            ),
+            timestep / 1000,
+        )
+
+        if isinstance(noise_pred, QTensor):
+            noise_pred = noise_pred.dequantize()
+
+        return noise_pred
+
+    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
+        if isinstance(prompt, str):
+            prompts = [prompt]
+        else:
+            prompts = prompt
+        if self.pipeline.text_encoder.device != self.device_torch:
+            self.pipeline.text_encoder.to(self.device_torch)
+
+        prompt_embeds, negative_embeds = self.pipeline.encode_prompt(
+            prompt=prompts,
+            negative_prompt=None,
+            device=self.text_encoder[0].device,
+            dtype=self.torch_dtype,
+        )
+
+        pe = PromptEmbeds(prompt_embeds)
+
+        return pe
+
+    def get_model_has_grad(self):
+        # return from a weight if it has grad
+        return False
+
+    def get_te_has_grad(self):
+        # return from a weight if it has grad
+        return False
+
+    def save_model(self, output_path, meta, save_dtype):
+        # only save the unet
+        transformer: DiT = unwrap_model(self.model)
+        # diffusers
+        # only save the unet
+        transformer: DiT = unwrap_model(self.transformer)
+        transformer.save_pretrained(
+            save_directory=os.path.join(output_path, 'dit_model'),
+            safe_serialization=True,
+        )
+        # save out meta config
+        meta_path = os.path.join(output_path, 'aitk_meta.yaml')
+        with open(meta_path, 'w') as f:
+            yaml.dump(meta, f)
+
+    def get_loss_target(self, *args, **kwargs):
+        noise = kwargs.get('noise')
+        batch = kwargs.get('batch')
+        # return (noise - batch.latents).detach()
+        return (batch.latents - noise).detach()
+
+    def convert_lora_weights_before_save(self, state_dict):
+        # currently starts with transformer. but needs to start with diffusion_model. for comfyui
+        new_sd = {}
+        for key, value in state_dict.items():
+            new_key = key.replace("transformer.", "diffusion_model.")
+            new_sd[new_key] = value
+        return new_sd
+
+    def convert_lora_weights_before_load(self, state_dict):
+        # saved as diffusion_model. but needs to be transformer. for ai-toolkit
+        new_sd = {}
+        for key, value in state_dict.items():
+            new_key = key.replace("diffusion_model.", "transformer.")
+            new_sd[new_key] = value
+        return new_sd
+
+    def get_base_model_version(self):
+        return "f-lite"
+
+    def get_stepped_pred(self, pred, noise):
+        # just used for DFE support
+        latents = pred + noise
+        return latents
diff --git a/extensions_built_in/diffusion_models/f_light/src/__init__.py b/extensions_built_in/diffusion_models/f_light/src/__init__.py
new file mode 100644
index 00000000..8e51652f
--- /dev/null
+++ b/extensions_built_in/diffusion_models/f_light/src/__init__.py
@@ -0,0 +1,5 @@
+from .pipeline import FLitePipeline, FLitePipelineOutput, APGConfig
+from .model import DiT
+
+
+__all__ = ["FLitePipeline", "FLitePipelineOutput", "APGConfig", "DiT"]
\ No newline at end of file
diff --git a/extensions_built_in/diffusion_models/f_light/src/model.py b/extensions_built_in/diffusion_models/f_light/src/model.py
new file mode 100644
index 00000000..903d4928
--- /dev/null
+++ b/extensions_built_in/diffusion_models/f_light/src/model.py
@@ -0,0 +1,456 @@
+# originally from https://github.com/fal-ai/f-lite/blob/main/f_lite/model.py but modified slightly
+
+import math
+
+import torch
+import torch.nn.functional as F
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils.accelerate_utils import apply_forward_hook
+from einops import rearrange
+from peft import get_peft_model_state_dict, set_peft_model_state_dict
+from torch import nn
+
+
+def timestep_embedding(t, dim, max_period=10000):
+    half = dim // 2
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+        device=t.device
+    )
+    args = t[:, None].float() * freqs[None]
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+
+    return embedding
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim, eps=1e-6, trainable=False):
+        super().__init__()
+        self.eps = eps
+        if trainable:
+            self.weight = nn.Parameter(torch.ones(dim))
+        else:
+            self.weight = None
+
+    def forward(self, x):
+        x_dtype = x.dtype
+        x = x.float()
+        norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+        if self.weight is not None:
+            return (x * norm * self.weight).to(dtype=x_dtype)
+        else:
+            return (x * norm).to(dtype=x_dtype)
+
+
+class QKNorm(nn.Module):
+    """Normalizing the query and the key independently, as Flux proposes"""
+
+    def __init__(self, dim, trainable=False):
+        super().__init__()
+        self.query_norm = RMSNorm(dim, trainable=trainable)
+        self.key_norm = RMSNorm(dim, trainable=trainable)
+
+    def forward(self, q, k):
+        q = self.query_norm(q)
+        k = self.key_norm(k)
+        return q, k
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dim,
+        num_heads=8,
+        qkv_bias=False,
+        is_self_attn=True,
+        cross_attn_input_size=None,
+        residual_v=False,
+        dynamic_softmax_temperature=False,
+    ):
+        super().__init__()
+        assert dim % num_heads == 0
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim**-0.5
+        self.is_self_attn = is_self_attn
+        self.residual_v = residual_v
+        self.dynamic_softmax_temperature = dynamic_softmax_temperature
+
+        if is_self_attn:
+            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        else:
+            self.q = nn.Linear(dim, dim, bias=qkv_bias)
+            self.context_kv = nn.Linear(cross_attn_input_size, dim * 2, bias=qkv_bias)
+
+        self.proj = nn.Linear(dim, dim, bias=False)
+
+        if residual_v:
+            self.lambda_param = nn.Parameter(torch.tensor(0.5).reshape(1))
+
+        self.qk_norm = QKNorm(self.head_dim)
+
+    def forward(self, x, context=None, v_0=None, rope=None):
+        if self.is_self_attn:
+            qkv = self.qkv(x)
+            qkv = rearrange(qkv, "b l (k h d) -> k b h l d", k=3, h=self.num_heads)
+            q, k, v = qkv.unbind(0)
+
+            if self.residual_v and v_0 is not None:
+                v = self.lambda_param * v + (1 - self.lambda_param) * v_0
+
+            if rope is not None:
+                # print(q.shape, rope[0].shape, rope[1].shape)
+                q = apply_rotary_emb(q, rope[0], rope[1])
+                k = apply_rotary_emb(k, rope[0], rope[1])
+
+            # https://arxiv.org/abs/2306.08645
+            # https://arxiv.org/abs/2410.01104
+            # rationale: as the token count grows, the categorical distribution gets more uniform,
+            # so you want to increase the entropy accordingly.
+
+            token_length = q.shape[2]
+            if self.dynamic_softmax_temperature:
+                ratio = math.sqrt(math.log(token_length) / math.log(1040.0))  # 1024 + 16
+                k = k * ratio
+            q, k = self.qk_norm(q, k)
+
+        else:
+            q = rearrange(self.q(x), "b l (h d) -> b h l d", h=self.num_heads)
+            kv = rearrange(
+                self.context_kv(context),
+                "b l (k h d) -> k b h l d",
+                k=2,
+                h=self.num_heads,
+            )
+            k, v = kv.unbind(0)
+            q, k = self.qk_norm(q, k)
+
+        x = F.scaled_dot_product_attention(q, k, v)
+        x = rearrange(x, "b h l d -> b l (h d)")
+        x = self.proj(x)
+        return x, v if self.is_self_attn else None
+
+
+class DiTBlock(nn.Module):
+    def __init__(
+        self,
+        hidden_size,
+        cross_attn_input_size,
+        num_heads,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        residual_v=False,
+        dynamic_softmax_temperature=False,
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.norm1 = RMSNorm(hidden_size, trainable=qkv_bias)
+        self.self_attn = Attention(
+            hidden_size,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            is_self_attn=True,
+            residual_v=residual_v,
+            dynamic_softmax_temperature=dynamic_softmax_temperature,
+        )
+
+        if cross_attn_input_size is not None:
+            self.norm2 = RMSNorm(hidden_size, trainable=qkv_bias)
+            self.cross_attn = Attention(
+                hidden_size,
+                num_heads=num_heads,
+                qkv_bias=qkv_bias,
+                is_self_attn=False,
+                cross_attn_input_size=cross_attn_input_size,
+                dynamic_softmax_temperature=dynamic_softmax_temperature,
+            )
+        else:
+            self.norm2 = None
+            self.cross_attn = None
+
+        self.norm3 = RMSNorm(hidden_size, trainable=qkv_bias)
+        mlp_hidden = int(hidden_size * mlp_ratio)
+        self.mlp = nn.Sequential(
+            nn.Linear(hidden_size, mlp_hidden),
+            nn.GELU(),
+            nn.Linear(mlp_hidden, hidden_size),
+        )
+
+        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 9 * hidden_size, bias=True))
+
+        self.adaLN_modulation[-1].weight.data.zero_()
+        self.adaLN_modulation[-1].bias.data.zero_()
+
+    # @torch.compile(mode='reduce-overhead')
+    def forward(self, x, context, c, v_0=None, rope=None):
+        (
+            shift_sa,
+            scale_sa,
+            gate_sa,
+            shift_ca,
+            scale_ca,
+            gate_ca,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+        ) = self.adaLN_modulation(c).chunk(9, dim=1)
+
+        scale_sa = scale_sa[:, None, :]
+        scale_ca = scale_ca[:, None, :]
+        scale_mlp = scale_mlp[:, None, :]
+
+        shift_sa = shift_sa[:, None, :]
+        shift_ca = shift_ca[:, None, :]
+        shift_mlp = shift_mlp[:, None, :]
+
+        gate_sa = gate_sa[:, None, :]
+        gate_ca = gate_ca[:, None, :]
+        gate_mlp = gate_mlp[:, None, :]
+
+        norm_x = self.norm1(x.clone())
+        norm_x = norm_x * (1 + scale_sa) + shift_sa
+        attn_out, v = self.self_attn(norm_x, v_0=v_0, rope=rope)
+        x = x + attn_out * gate_sa
+
+        if self.norm2 is not None:
+            norm_x = self.norm2(x)
+            norm_x = norm_x * (1 + scale_ca) + shift_ca
+            x = x + self.cross_attn(norm_x, context)[0] * gate_ca
+
+        norm_x = self.norm3(x)
+        norm_x = norm_x * (1 + scale_mlp) + shift_mlp
+        x = x + self.mlp(norm_x) * gate_mlp
+
+        return x, v
+
+
+class PatchEmbed(nn.Module):
+    def __init__(self, patch_size=16, in_channels=3, embed_dim=768):
+        super().__init__()
+        self.patch_proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+        self.patch_size = patch_size
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        x = self.patch_proj(x)
+        x = rearrange(x, "b c h w -> b (h w) c")
+        return x
+
+
+class TwoDimRotary(torch.nn.Module):
+    def __init__(self, dim, base=10000, h=256, w=256):
+        super().__init__()
+        self.inv_freq = torch.FloatTensor([1.0 / (base ** (i / dim)) for i in range(0, dim, 2)])
+        self.h = h
+        self.w = w
+
+        t_h = torch.arange(h, dtype=torch.float32)
+        t_w = torch.arange(w, dtype=torch.float32)
+
+        freqs_h = torch.outer(t_h, self.inv_freq).unsqueeze(1)  # h, 1, d / 2
+        freqs_w = torch.outer(t_w, self.inv_freq).unsqueeze(0)  # 1, w, d / 2
+        freqs_h = freqs_h.repeat(1, w, 1)  # h, w, d / 2
+        freqs_w = freqs_w.repeat(h, 1, 1)  # h, w, d / 2
+        freqs_hw = torch.cat([freqs_h, freqs_w], 2)  # h, w, d
+
+        self.register_buffer("freqs_hw_cos", freqs_hw.cos())
+        self.register_buffer("freqs_hw_sin", freqs_hw.sin())
+
+    def forward(self, x, height_width=None, extend_with_register_tokens=0):
+        if height_width is not None:
+            this_h, this_w = height_width
+        else:
+            this_hw = x.shape[1]
+            this_h, this_w = int(this_hw**0.5), int(this_hw**0.5)
+
+        cos = self.freqs_hw_cos[0 : this_h, 0 : this_w]
+        sin = self.freqs_hw_sin[0 : this_h, 0 : this_w]
+
+        cos = cos.clone().reshape(this_h * this_w, -1)
+        sin = sin.clone().reshape(this_h * this_w, -1)
+
+        # append N of zero-attn tokens
+        if extend_with_register_tokens > 0:
+            cos = torch.cat(
+                [
+                    torch.ones(extend_with_register_tokens, cos.shape[1]).to(cos.device),
+                    cos,
+                ],
+                0,
+            )
+            sin = torch.cat(
+                [
+                    torch.zeros(extend_with_register_tokens, sin.shape[1]).to(sin.device),
+                    sin,
+                ],
+                0,
+            )
+
+        return cos[None, None, :, :], sin[None, None, :, :]  # [1, 1, T + N, Attn-dim]
+
+
+def apply_rotary_emb(x, cos, sin):
+    orig_dtype = x.dtype
+    x = x.to(dtype=torch.float32)
+    assert x.ndim == 4  # multihead attention
+    d = x.shape[3] // 2
+    x1 = x[..., :d]
+    x2 = x[..., d:]
+    y1 = x1 * cos + x2 * sin
+    y2 = x1 * (-sin) + x2 * cos
+    return torch.cat([y1, y2], 3).to(dtype=orig_dtype)
+
+
+class DiT(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):  # type: ignore[misc]
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels=4,
+        patch_size=2,
+        hidden_size=1152,
+        depth=28,
+        num_heads=16,
+        mlp_ratio=4.0,
+        cross_attn_input_size=128,
+        residual_v=False,
+        train_bias_and_rms=True,
+        use_rope=True,
+        gradient_checkpoint=False,
+        dynamic_softmax_temperature=False,
+        rope_base=10000,
+    ):
+        super().__init__()
+
+        self.patch_embed = PatchEmbed(patch_size, in_channels, hidden_size)
+
+        if use_rope:
+            self.rope = TwoDimRotary(hidden_size // (2 * num_heads), base=rope_base, h=512, w=512)
+        else:
+            self.positional_embedding = nn.Parameter(torch.zeros(1, 2048, hidden_size))
+
+        self.register_tokens = nn.Parameter(torch.randn(1, 16, hidden_size))
+
+        self.time_embed = nn.Sequential(
+            nn.Linear(hidden_size, 4 * hidden_size),
+            nn.SiLU(),
+            nn.Linear(4 * hidden_size, hidden_size),
+        )
+
+        self.blocks = nn.ModuleList(
+            [
+                DiTBlock(
+                    hidden_size=hidden_size,
+                    num_heads=num_heads,
+                    mlp_ratio=mlp_ratio,
+                    cross_attn_input_size=cross_attn_input_size,
+                    residual_v=residual_v,
+                    qkv_bias=train_bias_and_rms,
+                    dynamic_softmax_temperature=dynamic_softmax_temperature,
+                )
+                for _ in range(depth)
+            ]
+        )
+
+        self.final_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+
+        self.final_norm = RMSNorm(hidden_size, trainable=train_bias_and_rms)
+        self.final_proj = nn.Linear(hidden_size, patch_size * patch_size * in_channels)
+        nn.init.zeros_(self.final_modulation[-1].weight)
+        nn.init.zeros_(self.final_modulation[-1].bias)
+        nn.init.zeros_(self.final_proj.weight)
+        nn.init.zeros_(self.final_proj.bias)
+        self.paramstatus = {}
+        for n, p in self.named_parameters():
+            self.paramstatus[n] = {
+                "shape": p.shape,
+                "requires_grad": p.requires_grad,
+            }
+        self.gradient_checkpointing = False
+
+    def save_lora_weights(self, save_directory):
+        """Save LoRA weights to a file"""
+        lora_state_dict = get_peft_model_state_dict(self)
+        torch.save(lora_state_dict, f"{save_directory}/lora_weights.pt")
+
+    def load_lora_weights(self, load_directory):
+        """Load LoRA weights from a file"""
+        lora_state_dict = torch.load(f"{load_directory}/lora_weights.pt")
+        set_peft_model_state_dict(self, lora_state_dict)
+
+    @apply_forward_hook
+    def forward(self, x, context, timesteps):
+        b, c, h, w = x.shape
+        x = self.patch_embed(x)  # b, T, d
+
+        x = torch.cat([self.register_tokens.repeat(b, 1, 1), x], 1)  # b, T + N, d
+
+        if self.config.use_rope:
+            cos, sin = self.rope(
+                x,
+                extend_with_register_tokens=16,
+                height_width=(h // self.config.patch_size, w // self.config.patch_size),
+            )
+        else:
+            x = x + self.positional_embedding.repeat(b, 1, 1)[:, : x.shape[1], :]
+            cos, sin = None, None
+
+        t_emb = timestep_embedding(timesteps * 1000, self.config.hidden_size).to(x.device, dtype=x.dtype)
+        t_emb = self.time_embed(t_emb)
+
+        v_0 = None
+
+        for _idx, block in enumerate(self.blocks):
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                x, v = self._gradient_checkpointing_func(
+                    block,
+                    x,
+                    context,
+                    t_emb,
+                    v_0,
+                    (cos, sin)
+                )
+            else:
+                x, v = block(x, context, t_emb, v_0, (cos, sin))
+            if v_0 is None:
+                v_0 = v
+
+        x = x[:, 16:, :]
+        final_shift, final_scale = self.final_modulation(t_emb).chunk(2, dim=1)
+        x = self.final_norm(x)
+        x = x * (1 + final_scale[:, None, :]) + final_shift[:, None, :]
+        x = self.final_proj(x)
+
+        x = rearrange(
+            x,
+            "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
+            h=h // self.config.patch_size,
+            w=w // self.config.patch_size,
+            p1=self.config.patch_size,
+            p2=self.config.patch_size,
+        )
+        return x
+
+
+if __name__ == "__main__":
+    model = DiT(
+        in_channels=4,
+        patch_size=2,
+        hidden_size=1152,
+        depth=28,
+        num_heads=16,
+        mlp_ratio=4.0,
+        cross_attn_input_size=128,
+        residual_v=False,
+        train_bias_and_rms=True,
+        use_rope=True,
+    ).cuda()
+    print(
+        model(
+            torch.randn(1, 4, 64, 64).cuda(),
+            torch.randn(1, 37, 128).cuda(),
+            torch.tensor([1.0]).cuda(),
+        )
+    )
\ No newline at end of file
diff --git a/extensions_built_in/diffusion_models/f_light/src/pipeline.py b/extensions_built_in/diffusion_models/f_light/src/pipeline.py
new file mode 100644
index 00000000..69fc67b9
--- /dev/null
+++ b/extensions_built_in/diffusion_models/f_light/src/pipeline.py
@@ -0,0 +1,308 @@
+# originally from https://github.com/fal-ai/f-lite/blob/main/f_lite/pipeline.py but modified slightly
+import logging
+import math
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from diffusers import AutoencoderKL, DiffusionPipeline
+from diffusers.utils import BaseOutput
+from diffusers.utils.torch_utils import randn_tensor
+from PIL import Image
+from torch import FloatTensor
+from tqdm.auto import tqdm
+from transformers import T5EncoderModel, T5TokenizerFast
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class APGConfig:
+    """APG (Augmented Parallel Guidance) configuration"""
+
+    enabled: bool = True
+    orthogonal_threshold: float = 0.03
+
+
+@dataclass
+class FLitePipelineOutput(BaseOutput):
+    """
+    Output class for FLitePipeline pipeline.
+    Args:
+        images (`List[PIL.Image.Image]` or `np.ndarray`)
+            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+    """
+
+    images: Union[List[Image.Image], np.ndarray]
+
+
+class FLitePipeline(DiffusionPipeline):
+    r"""
+    Pipeline for text-to-image generation using F-Lite model.
+    This model inherits from [`DiffusionPipeline`].
+    """
+
+    model_cpu_offload_seq = "text_encoder->dit_model->vae"
+
+    dit_model: torch.nn.Module
+    vae: AutoencoderKL
+    text_encoder: T5EncoderModel
+    tokenizer: T5TokenizerFast
+    _progress_bar_config: Dict[str, Any]
+
+    def __init__(
+        self, dit_model: torch.nn.Module, vae: AutoencoderKL, text_encoder: T5EncoderModel, tokenizer: T5TokenizerFast
+    ):
+        super().__init__()
+        # Register all modules for the pipeline
+        # Access DiffusionPipeline's register_modules directly to avoid mypy error
+        DiffusionPipeline.register_modules(
+            self, dit_model=dit_model, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer
+        )
+
+        # Move models to channels last for better performance
+        # AutoencoderKL inherits from torch.nn.Module which has these methods
+        if hasattr(self.vae, "to"):
+            self.vae.to(memory_format=torch.channels_last)
+        if hasattr(self.vae, "requires_grad_"):
+            self.vae.requires_grad_(False)
+        if hasattr(self.text_encoder, "requires_grad_"):
+            self.text_encoder.requires_grad_(False)
+
+        # Constants
+        self.vae_scale_factor = 8
+        self.return_index = -8  # T5 hidden state index to use
+
+    def enable_vae_slicing(self):
+        """Enable VAE slicing for memory efficiency."""
+        if hasattr(self.vae, "enable_slicing"):
+            self.vae.enable_slicing()
+
+    def enable_vae_tiling(self):
+        """Enable VAE tiling for memory efficiency."""
+        if hasattr(self.vae, "enable_tiling"):
+            self.vae.enable_tiling()
+
+    def set_progress_bar_config(self, **kwargs):
+        """Set progress bar configuration."""
+        self._progress_bar_config = kwargs
+
+    def progress_bar(self, iterable=None, **kwargs):
+        """Create progress bar for iterations."""
+        self._progress_bar_config = getattr(self, "_progress_bar_config", None) or {}
+        config = {**self._progress_bar_config, **kwargs}
+        return tqdm(iterable, **config)
+
+    def encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        max_sequence_length: int = 512,
+        return_index: int = -8,
+    ) -> Tuple[FloatTensor, FloatTensor]:
+        """Encodes the prompt and negative prompt."""
+        if isinstance(prompt, str):
+            prompt = [prompt]
+        device = device or self.text_encoder.device
+        # Text encoder forward pass
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids.to(device)
+        prompt_embeds = self.text_encoder(text_input_ids, return_dict=True, output_hidden_states=True)
+        prompt_embeds_tensor = prompt_embeds.hidden_states[return_index]
+        if return_index != -1:
+            prompt_embeds_tensor = self.text_encoder.encoder.final_layer_norm(prompt_embeds_tensor)
+            prompt_embeds_tensor = self.text_encoder.encoder.dropout(prompt_embeds_tensor)
+
+        dtype = dtype or next(self.text_encoder.parameters()).dtype
+        prompt_embeds_tensor = prompt_embeds_tensor.to(dtype=dtype, device=device)
+
+        # Handle negative prompts
+        if negative_prompt is None:
+            negative_embeds = torch.zeros_like(prompt_embeds_tensor)
+        else:
+            if isinstance(negative_prompt, str):
+                negative_prompt = [negative_prompt]
+            negative_result = self.encode_prompt(
+                prompt=negative_prompt, device=device, dtype=dtype, return_index=return_index
+            )
+            negative_embeds = negative_result[0]
+
+        # Explicitly cast both tensors to FloatTensor for mypy
+        from typing import cast
+
+        prompt_tensor = cast(FloatTensor, prompt_embeds_tensor.to(dtype=dtype))
+        negative_tensor = cast(FloatTensor, negative_embeds.to(dtype=dtype))
+        return (prompt_tensor, negative_tensor)
+
+    def to(self, torch_device=None, torch_dtype=None, silence_dtype_warnings=False):
+        """Move pipeline components to specified device and dtype."""
+        if hasattr(self, "vae"):
+            self.vae.to(device=torch_device, dtype=torch_dtype)
+        if hasattr(self, "text_encoder"):
+            self.text_encoder.to(device=torch_device, dtype=torch_dtype)
+        if hasattr(self, "dit_model"):
+            self.dit_model.to(device=torch_device, dtype=torch_dtype)
+        return self
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        prompt_embeds: Optional[FloatTensor] = None,
+        height: Optional[int] = 1024,
+        width: Optional[int] = 1024,
+        num_inference_steps: int = 30,
+        guidance_scale: float = 6.0,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_embeds: Optional[FloatTensor] = None,
+        num_images_per_prompt: int = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        dtype: Optional[torch.dtype] = None,
+        alpha: Optional[float] = None,
+        apg_config: Optional[APGConfig] = None,
+        **kwargs,
+    ):
+        """Generate images from text prompt."""
+        # Ensure height and width are not None for calculation
+        if height is None:
+            height = 1024
+        if width is None:
+            width = 1024
+
+        dtype = dtype or next(self.dit_model.parameters()).dtype
+        apg_config = apg_config or APGConfig(enabled=False)
+
+        device = self._execution_device
+
+        # 2. Encode prompts
+        prompt_batch_size = len(prompt) if isinstance(prompt, list) else 1
+        batch_size = prompt_batch_size * num_images_per_prompt
+
+        if prompt_embeds is None or negative_prompt_embeds is None:
+            prompt_embeds, negative_embeds = self.encode_prompt(
+                prompt=prompt, negative_prompt=negative_prompt, device=self.text_encoder.device, dtype=dtype,
+                return_index=self.return_index,
+            )
+        else:
+            negative_embeds = negative_prompt_embeds
+
+        # Repeat embeddings for num_images_per_prompt
+        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+        negative_embeds = negative_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+        # 3. Initialize latents
+        latent_height = height // self.vae_scale_factor
+        latent_width = width // self.vae_scale_factor
+
+        if isinstance(generator, list):
+            if len(generator) != batch_size:
+                raise ValueError(f"Got {len(generator)} generators for {batch_size} samples")
+
+        latents = randn_tensor((batch_size, 16, latent_height, latent_width), generator=generator, device=device, dtype=dtype)
+        acc_latents = latents.clone()
+
+        # 4. Calculate alpha if not provided
+        if alpha is None:
+            image_token_size = latent_height * latent_width
+            alpha = 2 * math.sqrt(image_token_size / (64 * 64))
+
+        # 6. Sampling loop
+        self.dit_model.eval()
+
+        # Check if guidance is needed
+        do_classifier_free_guidance = guidance_scale >= 1.0
+
+        for i in self.progress_bar(range(num_inference_steps, 0, -1)):
+            # Calculate timesteps
+            t = i / num_inference_steps
+            t_next = (i - 1) / num_inference_steps
+            # Scale timesteps according to alpha
+            t = t * alpha / (1 + (alpha - 1) * t)
+            t_next = t_next * alpha / (1 + (alpha - 1) * t_next)
+            dt = t - t_next
+
+            # Create tensor with proper device
+            t_tensor = torch.tensor([t] * batch_size, device=device, dtype=dtype)
+
+            if do_classifier_free_guidance:
+                # Duplicate latents for both conditional and unconditional inputs
+                latents_input = torch.cat([latents] * 2)
+                # Concatenate negative and positive prompt embeddings
+                context_input = torch.cat([negative_embeds, prompt_embeds])
+                # Duplicate timesteps for the batch
+                t_input = torch.cat([t_tensor] * 2)
+
+                # Get model predictions in a single pass
+                model_outputs = self.dit_model(latents_input, context_input, t_input)
+
+                # Split outputs back into unconditional and conditional predictions
+                uncond_output, cond_output = model_outputs.chunk(2)
+
+                if apg_config.enabled:
+                    # Augmented Parallel Guidance
+                    dy = cond_output
+                    dd = cond_output - uncond_output
+                    # Find parallel direction
+                    parallel_direction = (dy * dd).sum() / (dy * dy).sum() * dy
+                    orthogonal_direction = dd - parallel_direction
+                    # Scale orthogonal component
+                    orthogonal_std = orthogonal_direction.std()
+                    orthogonal_scale = min(1, apg_config.orthogonal_threshold / orthogonal_std)
+                    orthogonal_direction = orthogonal_direction * orthogonal_scale
+                    model_output = dy + (guidance_scale - 1) * orthogonal_direction
+                else:
+                    # Standard classifier-free guidance
+                    model_output = uncond_output + guidance_scale * (cond_output - uncond_output)
+            else:
+                # If no guidance needed, just run the model normally
+                model_output = self.dit_model(latents, prompt_embeds, t_tensor)
+
+            # Update latents
+            acc_latents = acc_latents + dt * model_output.to(device)
+            latents = acc_latents.clone()
+
+        # 7. Decode latents
+        # These checks handle the case where mypy doesn't recognize these attributes
+        scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215) if hasattr(self.vae, "config") else 0.18215
+        shift_factor = getattr(self.vae.config, "shift_factor", 0) if hasattr(self.vae, "config") else 0
+
+        latents = latents / scaling_factor + shift_factor
+
+        vae_dtype = self.vae.dtype if hasattr(self.vae, "dtype") else dtype
+        decoded_images = self.vae.decode(latents.to(vae_dtype)).sample if hasattr(self.vae, "decode") else latents
+
+        # Offload all models
+        try:
+            self.maybe_free_model_hooks()
+        except AttributeError as e:
+            if "OptimizedModule" in str(e):
+                import warnings
+                warnings.warn(
+                    "Encountered 'OptimizedModule' error when offloading models. "
+                    "This issue might be fixed in the future by: "
+                    "https://github.com/huggingface/diffusers/pull/10730"
+                )
+            else:
+                raise
+
+        # 8. Post-process images
+        images = (decoded_images / 2 + 0.5).clamp(0, 1)
+        # Convert to PIL Images
+        images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu()
+        pil_images = [Image.fromarray(img.permute(1, 2, 0).numpy()) for img in images]
+
+        return FLitePipelineOutput(
+            images=pil_images,
+        )
\ No newline at end of file
diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py
index 00623fa0..2ea29276 100644
--- a/toolkit/models/diffusion_feature_extraction.py
+++ b/toolkit/models/diffusion_feature_extraction.py
@@ -249,7 +249,8 @@ class DiffusionFeatureExtractor3(nn.Module):
         # lpips_weight=1.0,
         lpips_weight=10.0,
         clip_weight=0.1,
-        pixel_weight=0.1
+        pixel_weight=0.1,
+        model=None
     ):
         dtype = torch.bfloat16
         device = self.vae.device
@@ -274,7 +275,10 @@ class DiffusionFeatureExtractor3(nn.Module):
 
         # stepped_latents = torch.cat(stepped_chunks, dim=0)
 
-        stepped_latents = noise - noise_pred
+        if model is not None and hasattr(model, 'get_stepped_pred'):
+            stepped_latents = model.get_stepped_pred(noise_pred, noise)
+        else:
+            stepped_latents = noise - noise_pred
 
        latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
 
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 4ada3896..d35fc09e 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -2283,6 +2283,7 @@ class StableDiffusion:
         bleed_latents: torch.FloatTensor = None,
         is_input_scaled=False,
         return_first_prediction=False,
+        bypass_guidance_embedding=False,
         **kwargs,
     ):
         timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps]
@@ -2299,6 +2300,7 @@ class StableDiffusion:
             add_time_ids=add_time_ids,
             is_input_scaled=is_input_scaled,
             return_conditional_pred=True,
+            bypass_guidance_embedding=bypass_guidance_embedding,
             **kwargs,
         )
         # some schedulers need to run separately, so do that. (euler for example)
diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py
index e7c50beb..78e2183c 100644
--- a/toolkit/train_tools.py
+++ b/toolkit/train_tools.py
@@ -145,7 +145,7 @@ if TYPE_CHECKING:
 def concat_prompt_embeddings(
         unconditional: 'PromptEmbeds',
         conditional: 'PromptEmbeds',
-        n_imgs: int,
+        n_imgs: int=0,
 ):
     from toolkit.stable_diffusion_model import PromptEmbeds
     text_embeds = torch.cat(
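
Minimal usage sketch of the new FLitePipeline, for review only (not part of the patch). It uses only the constructor, to(), and __call__ arguments introduced above; the checkpoint path is a placeholder, and the dit_model/text_encoder/tokenizer/vae subfolder layout is an assumption based on what FLiteModel.load_model() expects, not a tested command.

# Review-only sketch: drive the FLitePipeline added in this patch directly.
# Assumes a checkpoint laid out with dit_model/, text_encoder/, tokenizer/, vae/
# subfolders (mirroring FLiteModel.load_model above); the path is a placeholder.
import torch
from diffusers import AutoencoderKL
from transformers import T5EncoderModel, T5TokenizerFast

from extensions_built_in.diffusion_models.f_light.src import DiT, FLitePipeline

ckpt = "/path/to/f-lite-checkpoint"  # placeholder path
dtype = torch.bfloat16

pipe = FLitePipeline(
    dit_model=DiT.from_pretrained(ckpt, subfolder="dit_model", torch_dtype=dtype),
    vae=AutoencoderKL.from_pretrained(ckpt, subfolder="vae", torch_dtype=dtype),
    text_encoder=T5EncoderModel.from_pretrained(ckpt, subfolder="text_encoder", torch_dtype=dtype),
    tokenizer=T5TokenizerFast.from_pretrained(ckpt, subfolder="tokenizer"),
).to("cuda", torch_dtype=dtype)  # custom to() takes (torch_device, torch_dtype)

image = pipe(
    prompt="a photo of a red fox in the snow",
    height=1024,
    width=1024,
    num_inference_steps=30,
    guidance_scale=6.0,
).images[0]
image.save("f_lite_sample.png")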