mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
@@ -5,6 +5,7 @@ from .omnigen2 import OmniGen2Model
|
|||||||
from .flux_kontext import FluxKontextModel
|
from .flux_kontext import FluxKontextModel
|
||||||
from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel
|
from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel
|
||||||
from .qwen_image import QwenImageModel, QwenImageEditModel, QwenImageEditPlusModel
|
from .qwen_image import QwenImageModel, QwenImageEditModel, QwenImageEditPlusModel
|
||||||
|
from .flux2 import Flux2Model
|
||||||
|
|
||||||
AI_TOOLKIT_MODELS = [
|
AI_TOOLKIT_MODELS = [
|
||||||
# put a list of models here
|
# put a list of models here
|
||||||
@@ -21,4 +22,5 @@ AI_TOOLKIT_MODELS = [
|
|||||||
QwenImageModel,
|
QwenImageModel,
|
||||||
QwenImageEditModel,
|
QwenImageEditModel,
|
||||||
QwenImageEditPlusModel,
|
QwenImageEditPlusModel,
|
||||||
|
Flux2Model,
|
||||||
]
|
]
|
||||||
|
|||||||
1
extensions_built_in/diffusion_models/flux2/__init__.py
Normal file
1
extensions_built_in/diffusion_models/flux2/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .flux2_model import Flux2Model
|
||||||
490
extensions_built_in/diffusion_models/flux2/flux2_model.py
Normal file
490
extensions_built_in/diffusion_models/flux2/flux2_model.py
Normal file
@@ -0,0 +1,490 @@
|
|||||||
|
import math
|
||||||
|
import os
|
||||||
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
|
import huggingface_hub
|
||||||
|
import torch
|
||||||
|
from toolkit.config_modules import GenerateImageConfig, ModelConfig
|
||||||
|
from toolkit.memory_management.manager import MemoryManager
|
||||||
|
from toolkit.metadata import get_meta_for_safetensors
|
||||||
|
from toolkit.models.base_model import BaseModel
|
||||||
|
from toolkit.basic import flush
|
||||||
|
from toolkit.prompt_utils import PromptEmbeds
|
||||||
|
from toolkit.samplers.custom_flowmatch_sampler import (
|
||||||
|
CustomFlowMatchEulerDiscreteScheduler,
|
||||||
|
)
|
||||||
|
from toolkit.dequantize import patch_dequantization_on_save
|
||||||
|
from toolkit.accelerator import unwrap_model
|
||||||
|
from optimum.quanto import freeze, QTensor
|
||||||
|
from toolkit.util.quantize import quantize, get_qtype, quantize_model
|
||||||
|
|
||||||
|
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
|
||||||
|
from .src.model import Flux2, Flux2Params
|
||||||
|
from .src.pipeline import Flux2Pipeline
|
||||||
|
from .src.autoencoder import AutoEncoder, AutoEncoderParams
|
||||||
|
from safetensors.torch import load_file, save_file
|
||||||
|
from PIL import Image
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||||
|
|
||||||
|
from .src.sampling import (
|
||||||
|
batched_prc_img,
|
||||||
|
batched_prc_txt,
|
||||||
|
encode_image_refs,
|
||||||
|
scatter_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
scheduler_config = {
|
||||||
|
"base_image_seq_len": 256,
|
||||||
|
"base_shift": 0.5,
|
||||||
|
"max_image_seq_len": 4096,
|
||||||
|
"max_shift": 1.15,
|
||||||
|
"num_train_timesteps": 1000,
|
||||||
|
"shift": 3.0,
|
||||||
|
"use_dynamic_shifting": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
MISTRAL_PATH = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||||
|
FLUX2_VAE_FILENAME = "ae.safetensors"
|
||||||
|
FLUX2_TRANSFORMER_FILENAME = "flux2-dev.safetensors"
|
||||||
|
|
||||||
|
HF_TOKEN = os.getenv("HF_TOKEN", None)
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2Model(BaseModel):
|
||||||
|
arch = "flux2"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
device,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
dtype="bf16",
|
||||||
|
custom_pipeline=None,
|
||||||
|
noise_scheduler=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
|
||||||
|
)
|
||||||
|
self.is_flow_matching = True
|
||||||
|
self.is_transformer = True
|
||||||
|
self.target_lora_modules = ["Flux2"]
|
||||||
|
# control images will come in as a list for encoding some things if true
|
||||||
|
self.has_multiple_control_images = True
|
||||||
|
# do not resize control images
|
||||||
|
self.use_raw_control_images = True
|
||||||
|
|
||||||
|
# static method to get the noise scheduler
|
||||||
|
@staticmethod
|
||||||
|
def get_train_scheduler():
|
||||||
|
return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
|
||||||
|
|
||||||
|
def get_bucket_divisibility(self):
|
||||||
|
return 16
|
||||||
|
|
||||||
|
def load_model(self):
|
||||||
|
dtype = self.torch_dtype
|
||||||
|
self.print_and_status_update("Loading Flux2 model")
|
||||||
|
# will be updated if we detect a existing checkpoint in training folder
|
||||||
|
model_path = self.model_config.name_or_path
|
||||||
|
transformer_path = model_path
|
||||||
|
|
||||||
|
self.print_and_status_update("Loading transformer")
|
||||||
|
with torch.device("meta"):
|
||||||
|
transformer = Flux2(Flux2Params())
|
||||||
|
|
||||||
|
# use local path if provided
|
||||||
|
if os.path.exists(os.path.join(transformer_path, FLUX2_TRANSFORMER_FILENAME)):
|
||||||
|
transformer_path = os.path.join(
|
||||||
|
transformer_path, FLUX2_TRANSFORMER_FILENAME
|
||||||
|
)
|
||||||
|
|
||||||
|
if not os.path.exists(transformer_path):
|
||||||
|
# assume it is from the hub
|
||||||
|
transformer_path = huggingface_hub.hf_hub_download(
|
||||||
|
repo_id=model_path,
|
||||||
|
filename=FLUX2_TRANSFORMER_FILENAME,
|
||||||
|
token=HF_TOKEN,
|
||||||
|
)
|
||||||
|
|
||||||
|
transformer_state_dict = load_file(transformer_path, device="cpu")
|
||||||
|
|
||||||
|
# cast to dtype
|
||||||
|
for key in transformer_state_dict:
|
||||||
|
transformer_state_dict[key] = transformer_state_dict[key].to(dtype)
|
||||||
|
|
||||||
|
transformer.load_state_dict(transformer_state_dict, assign=True)
|
||||||
|
|
||||||
|
transformer.to(self.quantize_device, dtype=dtype)
|
||||||
|
|
||||||
|
if self.model_config.quantize:
|
||||||
|
# patch the state dict method
|
||||||
|
patch_dequantization_on_save(transformer)
|
||||||
|
self.print_and_status_update("Quantizing Transformer")
|
||||||
|
quantize_model(self, transformer)
|
||||||
|
flush()
|
||||||
|
else:
|
||||||
|
transformer.to(self.device_torch, dtype=dtype)
|
||||||
|
flush()
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.model_config.layer_offloading
|
||||||
|
and self.model_config.layer_offloading_transformer_percent > 0
|
||||||
|
):
|
||||||
|
MemoryManager.attach(
|
||||||
|
transformer,
|
||||||
|
self.device_torch,
|
||||||
|
offload_percent=self.model_config.layer_offloading_transformer_percent,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.model_config.low_vram:
|
||||||
|
self.print_and_status_update("Moving transformer to CPU")
|
||||||
|
transformer.to("cpu")
|
||||||
|
|
||||||
|
self.print_and_status_update("Loading Mistral")
|
||||||
|
|
||||||
|
text_encoder: Mistral3ForConditionalGeneration = (
|
||||||
|
Mistral3ForConditionalGeneration.from_pretrained(
|
||||||
|
MISTRAL_PATH,
|
||||||
|
torch_dtype=dtype,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
text_encoder.to(self.device_torch, dtype=dtype)
|
||||||
|
|
||||||
|
flush()
|
||||||
|
|
||||||
|
if self.model_config.quantize_te:
|
||||||
|
self.print_and_status_update("Quantizing Mistral")
|
||||||
|
quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
|
||||||
|
freeze(text_encoder)
|
||||||
|
flush()
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.model_config.layer_offloading
|
||||||
|
and self.model_config.layer_offloading_text_encoder_percent > 0
|
||||||
|
):
|
||||||
|
MemoryManager.attach(
|
||||||
|
text_encoder,
|
||||||
|
self.device_torch,
|
||||||
|
offload_percent=self.model_config.layer_offloading_text_encoder_percent,
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenizer = AutoProcessor.from_pretrained(MISTRAL_PATH)
|
||||||
|
|
||||||
|
self.print_and_status_update("Loading VAE")
|
||||||
|
vae_path = self.model_config.vae_path
|
||||||
|
|
||||||
|
if os.path.exists(os.path.join(model_path, FLUX2_VAE_FILENAME)):
|
||||||
|
vae_path = os.path.join(model_path, FLUX2_VAE_FILENAME)
|
||||||
|
|
||||||
|
if vae_path is None or not os.path.exists(vae_path):
|
||||||
|
# assume it is from the hub
|
||||||
|
vae_path = huggingface_hub.hf_hub_download(
|
||||||
|
repo_id=model_path,
|
||||||
|
filename=FLUX2_VAE_FILENAME,
|
||||||
|
token=HF_TOKEN,
|
||||||
|
)
|
||||||
|
with torch.device("meta"):
|
||||||
|
vae = AutoEncoder(AutoEncoderParams())
|
||||||
|
|
||||||
|
vae_state_dict = load_file(vae_path, device="cpu")
|
||||||
|
|
||||||
|
# cast to dtype
|
||||||
|
for key in vae_state_dict:
|
||||||
|
vae_state_dict[key] = vae_state_dict[key].to(dtype)
|
||||||
|
|
||||||
|
vae.load_state_dict(vae_state_dict, assign=True)
|
||||||
|
|
||||||
|
self.noise_scheduler = Flux2Model.get_train_scheduler()
|
||||||
|
|
||||||
|
self.print_and_status_update("Making pipe")
|
||||||
|
|
||||||
|
pipe: Flux2Pipeline = Flux2Pipeline(
|
||||||
|
scheduler=self.noise_scheduler,
|
||||||
|
text_encoder=text_encoder,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
vae=vae,
|
||||||
|
transformer=None,
|
||||||
|
)
|
||||||
|
# for quantization, it works best to do these after making the pipe
|
||||||
|
pipe.transformer = transformer
|
||||||
|
|
||||||
|
self.print_and_status_update("Preparing Model")
|
||||||
|
|
||||||
|
text_encoder = [pipe.text_encoder]
|
||||||
|
tokenizer = [pipe.tokenizer]
|
||||||
|
|
||||||
|
flush()
|
||||||
|
# just to make sure everything is on the right device and dtype
|
||||||
|
text_encoder[0].to(self.device_torch)
|
||||||
|
text_encoder[0].requires_grad_(False)
|
||||||
|
text_encoder[0].eval()
|
||||||
|
pipe.transformer = pipe.transformer.to(self.device_torch)
|
||||||
|
flush()
|
||||||
|
|
||||||
|
# save it to the model class
|
||||||
|
self.vae = vae
|
||||||
|
self.text_encoder = text_encoder # list of text encoders
|
||||||
|
self.tokenizer = tokenizer # list of tokenizers
|
||||||
|
self.model = pipe.transformer
|
||||||
|
self.pipeline = pipe
|
||||||
|
self.print_and_status_update("Model Loaded")
|
||||||
|
|
||||||
|
def get_generation_pipeline(self):
|
||||||
|
scheduler = Flux2Model.get_train_scheduler()
|
||||||
|
|
||||||
|
pipeline: Flux2Pipeline = Flux2Pipeline(
|
||||||
|
scheduler=scheduler,
|
||||||
|
text_encoder=unwrap_model(self.text_encoder[0]),
|
||||||
|
tokenizer=self.tokenizer[0],
|
||||||
|
vae=unwrap_model(self.vae),
|
||||||
|
transformer=unwrap_model(self.transformer),
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline = pipeline.to(self.device_torch)
|
||||||
|
|
||||||
|
return pipeline
|
||||||
|
|
||||||
|
def generate_single_image(
|
||||||
|
self,
|
||||||
|
pipeline: Flux2Pipeline,
|
||||||
|
gen_config: GenerateImageConfig,
|
||||||
|
conditional_embeds: PromptEmbeds,
|
||||||
|
unconditional_embeds: PromptEmbeds,
|
||||||
|
generator: torch.Generator,
|
||||||
|
extra: dict,
|
||||||
|
):
|
||||||
|
gen_config.width = (
|
||||||
|
gen_config.width // self.get_bucket_divisibility()
|
||||||
|
) * self.get_bucket_divisibility()
|
||||||
|
gen_config.height = (
|
||||||
|
gen_config.height // self.get_bucket_divisibility()
|
||||||
|
) * self.get_bucket_divisibility()
|
||||||
|
|
||||||
|
control_img_list = []
|
||||||
|
if gen_config.ctrl_img is not None:
|
||||||
|
control_img = Image.open(gen_config.ctrl_img)
|
||||||
|
control_img = control_img.convert("RGB")
|
||||||
|
control_img_list.append(control_img)
|
||||||
|
elif gen_config.ctrl_img_1 is not None:
|
||||||
|
control_img = Image.open(gen_config.ctrl_img_1)
|
||||||
|
control_img = control_img.convert("RGB")
|
||||||
|
control_img_list.append(control_img)
|
||||||
|
if gen_config.ctrl_img_2 is not None:
|
||||||
|
control_img = Image.open(gen_config.ctrl_img_2)
|
||||||
|
control_img = control_img.convert("RGB")
|
||||||
|
control_img_list.append(control_img)
|
||||||
|
if gen_config.ctrl_img_3 is not None:
|
||||||
|
control_img = Image.open(gen_config.ctrl_img_3)
|
||||||
|
control_img = control_img.convert("RGB")
|
||||||
|
control_img_list.append(control_img)
|
||||||
|
|
||||||
|
img = pipeline(
|
||||||
|
prompt_embeds=conditional_embeds.text_embeds,
|
||||||
|
height=gen_config.height,
|
||||||
|
width=gen_config.width,
|
||||||
|
num_inference_steps=gen_config.num_inference_steps,
|
||||||
|
guidance_scale=gen_config.guidance_scale,
|
||||||
|
latents=gen_config.latents,
|
||||||
|
generator=generator,
|
||||||
|
control_img_list=control_img_list,
|
||||||
|
**extra,
|
||||||
|
).images[0]
|
||||||
|
return img
|
||||||
|
|
||||||
|
def get_noise_prediction(
|
||||||
|
self,
|
||||||
|
latent_model_input: torch.Tensor,
|
||||||
|
timestep: torch.Tensor, # 0 to 1000 scale
|
||||||
|
text_embeddings: PromptEmbeds,
|
||||||
|
guidance_embedding_scale: float,
|
||||||
|
batch: "DataLoaderBatchDTO" = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
with torch.no_grad():
|
||||||
|
txt, txt_ids = batched_prc_txt(text_embeddings.text_embeds)
|
||||||
|
packed_latents, img_ids = batched_prc_img(latent_model_input)
|
||||||
|
|
||||||
|
# prepare image conditioning if any
|
||||||
|
img_cond_seq: torch.Tensor | None = None
|
||||||
|
img_cond_seq_ids: torch.Tensor | None = None
|
||||||
|
|
||||||
|
# handle control images
|
||||||
|
if batch.control_tensor_list is not None:
|
||||||
|
batch_size, num_channels_latents, height, width = (
|
||||||
|
latent_model_input.shape
|
||||||
|
)
|
||||||
|
|
||||||
|
control_image_max_res = 1024 * 1024
|
||||||
|
if self.model_config.model_kwargs.get("match_target_res", False):
|
||||||
|
# use the current target size to set the control image res
|
||||||
|
control_image_res = (
|
||||||
|
height
|
||||||
|
* self.pipeline.vae_scale_factor
|
||||||
|
* width
|
||||||
|
* self.pipeline.vae_scale_factor
|
||||||
|
)
|
||||||
|
control_image_max_res = control_image_res
|
||||||
|
|
||||||
|
if len(batch.control_tensor_list) != batch_size:
|
||||||
|
raise ValueError(
|
||||||
|
"Control tensor list length does not match batch size"
|
||||||
|
)
|
||||||
|
for control_tensor_list in batch.control_tensor_list:
|
||||||
|
# control tensor list is a list of tensors for this batch item
|
||||||
|
controls = []
|
||||||
|
# pack control
|
||||||
|
for control_img in control_tensor_list:
|
||||||
|
# control images are 0 - 1 scale, shape (1, ch, height, width)
|
||||||
|
control_img = control_img.to(
|
||||||
|
self.device_torch, dtype=self.torch_dtype
|
||||||
|
)
|
||||||
|
# if it is only 3 dim, add batch dim
|
||||||
|
if len(control_img.shape) == 3:
|
||||||
|
control_img = control_img.unsqueeze(0)
|
||||||
|
|
||||||
|
# resize to fit within max res while keeping aspect ratio
|
||||||
|
if self.model_config.model_kwargs.get(
|
||||||
|
"match_target_res", False
|
||||||
|
):
|
||||||
|
ratio = control_img.shape[2] / control_img.shape[3]
|
||||||
|
c_width = math.sqrt(control_image_res * ratio)
|
||||||
|
c_height = c_width / ratio
|
||||||
|
|
||||||
|
c_width = round(c_width / 32) * 32
|
||||||
|
c_height = round(c_height / 32) * 32
|
||||||
|
|
||||||
|
control_img = F.interpolate(
|
||||||
|
control_img, size=(c_height, c_width), mode="bilinear"
|
||||||
|
)
|
||||||
|
|
||||||
|
# scale to -1 to 1
|
||||||
|
control_img = control_img * 2 - 1
|
||||||
|
controls.append(control_img)
|
||||||
|
|
||||||
|
img_cond_seq_item, img_cond_seq_ids_item = encode_image_refs(
|
||||||
|
self.vae, controls, limit_pixels=control_image_max_res
|
||||||
|
)
|
||||||
|
if img_cond_seq is None:
|
||||||
|
img_cond_seq = img_cond_seq_item
|
||||||
|
img_cond_seq_ids = img_cond_seq_ids_item
|
||||||
|
else:
|
||||||
|
img_cond_seq = torch.cat(
|
||||||
|
(img_cond_seq, img_cond_seq_item), dim=0
|
||||||
|
)
|
||||||
|
img_cond_seq_ids = torch.cat(
|
||||||
|
(img_cond_seq_ids, img_cond_seq_ids_item), dim=0
|
||||||
|
)
|
||||||
|
|
||||||
|
img_input = packed_latents
|
||||||
|
img_input_ids = img_ids
|
||||||
|
|
||||||
|
if img_cond_seq is not None:
|
||||||
|
assert img_cond_seq_ids is not None, (
|
||||||
|
"You need to provide either both or neither of the sequence conditioning"
|
||||||
|
)
|
||||||
|
img_input = torch.cat((img_input, img_cond_seq), dim=1)
|
||||||
|
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
|
||||||
|
|
||||||
|
guidance_vec = torch.full(
|
||||||
|
(img_input.shape[0],),
|
||||||
|
guidance_embedding_scale,
|
||||||
|
device=img_input.device,
|
||||||
|
dtype=img_input.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
cast_dtype = self.model.dtype
|
||||||
|
|
||||||
|
packed_noise_pred = self.transformer(
|
||||||
|
x=img_input.to(self.device_torch, cast_dtype),
|
||||||
|
x_ids=img_input_ids.to(self.device_torch),
|
||||||
|
timesteps=timestep.to(self.device_torch, cast_dtype) / 1000,
|
||||||
|
ctx=txt.to(self.device_torch, cast_dtype),
|
||||||
|
ctx_ids=txt_ids.to(self.device_torch),
|
||||||
|
guidance=guidance_vec.to(self.device_torch, cast_dtype),
|
||||||
|
)
|
||||||
|
|
||||||
|
if img_cond_seq is not None:
|
||||||
|
packed_noise_pred = packed_noise_pred[:, : packed_latents.shape[1]]
|
||||||
|
|
||||||
|
if isinstance(packed_noise_pred, QTensor):
|
||||||
|
packed_noise_pred = packed_noise_pred.dequantize()
|
||||||
|
|
||||||
|
noise_pred = torch.cat(scatter_ids(packed_noise_pred, img_ids)).squeeze(2)
|
||||||
|
|
||||||
|
return noise_pred
|
||||||
|
|
||||||
|
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
|
||||||
|
if self.pipeline.text_encoder.device != self.device_torch:
|
||||||
|
self.pipeline.text_encoder.to(self.device_torch)
|
||||||
|
|
||||||
|
prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt(
|
||||||
|
prompt, device=self.device_torch
|
||||||
|
)
|
||||||
|
pe = PromptEmbeds(prompt_embeds)
|
||||||
|
return pe
|
||||||
|
|
||||||
|
def get_model_has_grad(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_te_has_grad(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def save_model(self, output_path, meta, save_dtype):
|
||||||
|
if not output_path.endswith(".safetensors"):
|
||||||
|
output_path = output_path + ".safetensors"
|
||||||
|
# only save the unet
|
||||||
|
transformer: Flux2 = unwrap_model(self.model)
|
||||||
|
state_dict = transformer.state_dict()
|
||||||
|
save_dict = {}
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
if isinstance(v, QTensor):
|
||||||
|
v = v.dequantize()
|
||||||
|
save_dict[k] = v.clone().to("cpu", dtype=save_dtype)
|
||||||
|
|
||||||
|
meta = get_meta_for_safetensors(meta, name="flux2")
|
||||||
|
save_file(save_dict, output_path, metadata=meta)
|
||||||
|
|
||||||
|
def get_loss_target(self, *args, **kwargs):
|
||||||
|
noise = kwargs.get("noise")
|
||||||
|
batch = kwargs.get("batch")
|
||||||
|
return (noise - batch.latents).detach()
|
||||||
|
|
||||||
|
def get_base_model_version(self):
|
||||||
|
return "flux2"
|
||||||
|
|
||||||
|
def get_transformer_block_names(self) -> Optional[List[str]]:
|
||||||
|
return ["double_blocks", "single_blocks"]
|
||||||
|
|
||||||
|
def convert_lora_weights_before_save(self, state_dict):
|
||||||
|
new_sd = {}
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
new_key = key.replace("transformer.", "diffusion_model.")
|
||||||
|
new_sd[new_key] = value
|
||||||
|
return new_sd
|
||||||
|
|
||||||
|
def convert_lora_weights_before_load(self, state_dict):
|
||||||
|
new_sd = {}
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
new_key = key.replace("diffusion_model.", "transformer.")
|
||||||
|
new_sd[new_key] = value
|
||||||
|
return new_sd
|
||||||
|
|
||||||
|
def encode_images(self, image_list: List[torch.Tensor], device=None, dtype=None):
|
||||||
|
if device is None:
|
||||||
|
device = self.vae_device_torch
|
||||||
|
if dtype is None:
|
||||||
|
dtype = self.vae_torch_dtype
|
||||||
|
|
||||||
|
# Move to vae to device if on cpu
|
||||||
|
if self.vae.device == torch.device("cpu"):
|
||||||
|
self.vae.to(device)
|
||||||
|
# move to device and dtype
|
||||||
|
image_list = [image.to(device, dtype=dtype) for image in image_list]
|
||||||
|
images = torch.stack(image_list).to(device, dtype=dtype)
|
||||||
|
|
||||||
|
latents = self.vae.encode(images)
|
||||||
|
|
||||||
|
return latents
|
||||||
370
extensions_built_in/diffusion_models/flux2/src/autoencoder.py
Normal file
370
extensions_built_in/diffusion_models/flux2/src/autoencoder.py
Normal file
@@ -0,0 +1,370 @@
|
|||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import Tensor, nn
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AutoEncoderParams:
|
||||||
|
resolution: int = 256
|
||||||
|
in_channels: int = 3
|
||||||
|
ch: int = 128
|
||||||
|
out_ch: int = 3
|
||||||
|
ch_mult: list[int] = field(default_factory=lambda: [1, 2, 4, 4])
|
||||||
|
num_res_blocks: int = 2
|
||||||
|
z_channels: int = 32
|
||||||
|
|
||||||
|
|
||||||
|
def swish(x: Tensor) -> Tensor:
|
||||||
|
return x * torch.sigmoid(x)
|
||||||
|
|
||||||
|
|
||||||
|
class AttnBlock(nn.Module):
|
||||||
|
def __init__(self, in_channels: int):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
self.norm = nn.GroupNorm(
|
||||||
|
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
||||||
|
)
|
||||||
|
|
||||||
|
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||||
|
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||||
|
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||||
|
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||||
|
|
||||||
|
def attention(self, h_: Tensor) -> Tensor:
|
||||||
|
h_ = self.norm(h_)
|
||||||
|
q = self.q(h_)
|
||||||
|
k = self.k(h_)
|
||||||
|
v = self.v(h_)
|
||||||
|
|
||||||
|
b, c, h, w = q.shape
|
||||||
|
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
|
||||||
|
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
|
||||||
|
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
|
||||||
|
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
||||||
|
|
||||||
|
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return x + self.proj_out(self.attention(x))
|
||||||
|
|
||||||
|
|
||||||
|
class ResnetBlock(nn.Module):
|
||||||
|
def __init__(self, in_channels: int, out_channels: int):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
out_channels = in_channels if out_channels is None else out_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
|
||||||
|
self.norm1 = nn.GroupNorm(
|
||||||
|
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
||||||
|
)
|
||||||
|
self.conv1 = nn.Conv2d(
|
||||||
|
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
self.norm2 = nn.GroupNorm(
|
||||||
|
num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
|
||||||
|
)
|
||||||
|
self.conv2 = nn.Conv2d(
|
||||||
|
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
if self.in_channels != self.out_channels:
|
||||||
|
self.nin_shortcut = nn.Conv2d(
|
||||||
|
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h = x
|
||||||
|
h = self.norm1(h)
|
||||||
|
h = swish(h)
|
||||||
|
h = self.conv1(h)
|
||||||
|
|
||||||
|
h = self.norm2(h)
|
||||||
|
h = swish(h)
|
||||||
|
h = self.conv2(h)
|
||||||
|
|
||||||
|
if self.in_channels != self.out_channels:
|
||||||
|
x = self.nin_shortcut(x)
|
||||||
|
|
||||||
|
return x + h
|
||||||
|
|
||||||
|
|
||||||
|
class Downsample(nn.Module):
|
||||||
|
def __init__(self, in_channels: int):
|
||||||
|
super().__init__()
|
||||||
|
# no asymmetric padding in torch conv, must do it ourselves
|
||||||
|
self.conv = nn.Conv2d(
|
||||||
|
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor):
|
||||||
|
pad = (0, 1, 0, 1)
|
||||||
|
x = nn.functional.pad(x, pad, mode="constant", value=0)
|
||||||
|
x = self.conv(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample(nn.Module):
|
||||||
|
def __init__(self, in_channels: int):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Conv2d(
|
||||||
|
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor):
|
||||||
|
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||||
|
x = self.conv(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
resolution: int,
|
||||||
|
in_channels: int,
|
||||||
|
ch: int,
|
||||||
|
ch_mult: list[int],
|
||||||
|
num_res_blocks: int,
|
||||||
|
z_channels: int,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * z_channels, 1)
|
||||||
|
self.ch = ch
|
||||||
|
self.num_resolutions = len(ch_mult)
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.resolution = resolution
|
||||||
|
self.in_channels = in_channels
|
||||||
|
# downsampling
|
||||||
|
self.conv_in = nn.Conv2d(
|
||||||
|
in_channels, self.ch, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
curr_res = resolution
|
||||||
|
in_ch_mult = (1,) + tuple(ch_mult)
|
||||||
|
self.in_ch_mult = in_ch_mult
|
||||||
|
self.down = nn.ModuleList()
|
||||||
|
block_in = self.ch
|
||||||
|
for i_level in range(self.num_resolutions):
|
||||||
|
block = nn.ModuleList()
|
||||||
|
attn = nn.ModuleList()
|
||||||
|
block_in = ch * in_ch_mult[i_level]
|
||||||
|
block_out = ch * ch_mult[i_level]
|
||||||
|
for _ in range(self.num_res_blocks):
|
||||||
|
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
||||||
|
block_in = block_out
|
||||||
|
down = nn.Module()
|
||||||
|
down.block = block
|
||||||
|
down.attn = attn
|
||||||
|
if i_level != self.num_resolutions - 1:
|
||||||
|
down.downsample = Downsample(block_in)
|
||||||
|
curr_res = curr_res // 2
|
||||||
|
self.down.append(down)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
self.mid = nn.Module()
|
||||||
|
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||||
|
self.mid.attn_1 = AttnBlock(block_in)
|
||||||
|
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||||
|
|
||||||
|
# end
|
||||||
|
self.norm_out = nn.GroupNorm(
|
||||||
|
num_groups=32, num_channels=block_in, eps=1e-6, affine=True
|
||||||
|
)
|
||||||
|
self.conv_out = nn.Conv2d(
|
||||||
|
block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
# downsampling
|
||||||
|
hs = [self.conv_in(x)]
|
||||||
|
for i_level in range(self.num_resolutions):
|
||||||
|
for i_block in range(self.num_res_blocks):
|
||||||
|
h = self.down[i_level].block[i_block](hs[-1])
|
||||||
|
if len(self.down[i_level].attn) > 0:
|
||||||
|
h = self.down[i_level].attn[i_block](h)
|
||||||
|
hs.append(h)
|
||||||
|
if i_level != self.num_resolutions - 1:
|
||||||
|
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||||
|
|
||||||
|
# middle
|
||||||
|
h = hs[-1]
|
||||||
|
h = self.mid.block_1(h)
|
||||||
|
h = self.mid.attn_1(h)
|
||||||
|
h = self.mid.block_2(h)
|
||||||
|
# end
|
||||||
|
h = self.norm_out(h)
|
||||||
|
h = swish(h)
|
||||||
|
h = self.conv_out(h)
|
||||||
|
h = self.quant_conv(h)
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ch: int,
|
||||||
|
out_ch: int,
|
||||||
|
ch_mult: list[int],
|
||||||
|
num_res_blocks: int,
|
||||||
|
in_channels: int,
|
||||||
|
resolution: int,
|
||||||
|
z_channels: int,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.post_quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
|
||||||
|
self.ch = ch
|
||||||
|
self.num_resolutions = len(ch_mult)
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.resolution = resolution
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.ffactor = 2 ** (self.num_resolutions - 1)
|
||||||
|
|
||||||
|
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||||
|
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||||
|
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||||
|
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||||
|
|
||||||
|
# z to block_in
|
||||||
|
self.conv_in = nn.Conv2d(
|
||||||
|
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
self.mid = nn.Module()
|
||||||
|
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||||
|
self.mid.attn_1 = AttnBlock(block_in)
|
||||||
|
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||||
|
|
||||||
|
# upsampling
|
||||||
|
self.up = nn.ModuleList()
|
||||||
|
for i_level in reversed(range(self.num_resolutions)):
|
||||||
|
block = nn.ModuleList()
|
||||||
|
attn = nn.ModuleList()
|
||||||
|
block_out = ch * ch_mult[i_level]
|
||||||
|
for _ in range(self.num_res_blocks + 1):
|
||||||
|
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
||||||
|
block_in = block_out
|
||||||
|
up = nn.Module()
|
||||||
|
up.block = block
|
||||||
|
up.attn = attn
|
||||||
|
if i_level != 0:
|
||||||
|
up.upsample = Upsample(block_in)
|
||||||
|
curr_res = curr_res * 2
|
||||||
|
self.up.insert(0, up) # prepend to get consistent order
|
||||||
|
|
||||||
|
# end
|
||||||
|
self.norm_out = nn.GroupNorm(
|
||||||
|
num_groups=32, num_channels=block_in, eps=1e-6, affine=True
|
||||||
|
)
|
||||||
|
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
def forward(self, z: Tensor) -> Tensor:
|
||||||
|
z = self.post_quant_conv(z)
|
||||||
|
|
||||||
|
# get dtype for proper tracing
|
||||||
|
upscale_dtype = next(self.up.parameters()).dtype
|
||||||
|
|
||||||
|
# z to block_in
|
||||||
|
h = self.conv_in(z)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
h = self.mid.block_1(h)
|
||||||
|
h = self.mid.attn_1(h)
|
||||||
|
h = self.mid.block_2(h)
|
||||||
|
|
||||||
|
# cast to proper dtype
|
||||||
|
h = h.to(upscale_dtype)
|
||||||
|
# upsampling
|
||||||
|
for i_level in reversed(range(self.num_resolutions)):
|
||||||
|
for i_block in range(self.num_res_blocks + 1):
|
||||||
|
h = self.up[i_level].block[i_block](h)
|
||||||
|
if len(self.up[i_level].attn) > 0:
|
||||||
|
h = self.up[i_level].attn[i_block](h)
|
||||||
|
if i_level != 0:
|
||||||
|
h = self.up[i_level].upsample(h)
|
||||||
|
|
||||||
|
# end
|
||||||
|
h = self.norm_out(h)
|
||||||
|
h = swish(h)
|
||||||
|
h = self.conv_out(h)
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class AutoEncoder(nn.Module):
|
||||||
|
def __init__(self, params: AutoEncoderParams):
|
||||||
|
super().__init__()
|
||||||
|
self.params = params
|
||||||
|
self.encoder = Encoder(
|
||||||
|
resolution=params.resolution,
|
||||||
|
in_channels=params.in_channels,
|
||||||
|
ch=params.ch,
|
||||||
|
ch_mult=params.ch_mult,
|
||||||
|
num_res_blocks=params.num_res_blocks,
|
||||||
|
z_channels=params.z_channels,
|
||||||
|
)
|
||||||
|
self.decoder = Decoder(
|
||||||
|
resolution=params.resolution,
|
||||||
|
in_channels=params.in_channels,
|
||||||
|
ch=params.ch,
|
||||||
|
out_ch=params.out_ch,
|
||||||
|
ch_mult=params.ch_mult,
|
||||||
|
num_res_blocks=params.num_res_blocks,
|
||||||
|
z_channels=params.z_channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.bn_eps = 1e-4
|
||||||
|
self.bn_momentum = 0.1
|
||||||
|
self.ps = [2, 2]
|
||||||
|
self.bn = torch.nn.BatchNorm2d(
|
||||||
|
math.prod(self.ps) * params.z_channels,
|
||||||
|
eps=self.bn_eps,
|
||||||
|
momentum=self.bn_momentum,
|
||||||
|
affine=False,
|
||||||
|
track_running_stats=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return next(self.parameters()).device
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
return next(self.parameters()).dtype
|
||||||
|
|
||||||
|
def normalize(self, z):
|
||||||
|
self.bn.eval()
|
||||||
|
return self.bn(z)
|
||||||
|
|
||||||
|
def inv_normalize(self, z):
|
||||||
|
self.bn.eval()
|
||||||
|
s = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + self.bn_eps)
|
||||||
|
m = self.bn.running_mean.view(1, -1, 1, 1)
|
||||||
|
return z * s + m
|
||||||
|
|
||||||
|
def encode(self, x: Tensor) -> Tensor:
|
||||||
|
moments = self.encoder(x)
|
||||||
|
mean = torch.chunk(moments, 2, dim=1)[0]
|
||||||
|
|
||||||
|
z = rearrange(
|
||||||
|
mean,
|
||||||
|
"... c (i pi) (j pj) -> ... (c pi pj) i j",
|
||||||
|
pi=self.ps[0],
|
||||||
|
pj=self.ps[1],
|
||||||
|
)
|
||||||
|
z = self.normalize(z)
|
||||||
|
return z
|
||||||
|
|
||||||
|
def decode(self, z: Tensor) -> Tensor:
|
||||||
|
z = self.inv_normalize(z)
|
||||||
|
z = rearrange(
|
||||||
|
z,
|
||||||
|
"... (c pi pj) i j -> ... c (i pi) (j pj)",
|
||||||
|
pi=self.ps[0],
|
||||||
|
pj=self.ps[1],
|
||||||
|
)
|
||||||
|
dec = self.decoder(z)
|
||||||
|
return dec
|
||||||
520
extensions_built_in/diffusion_models/flux2/src/model.py
Normal file
520
extensions_built_in/diffusion_models/flux2/src/model.py
Normal file
@@ -0,0 +1,520 @@
|
|||||||
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import Tensor, nn
|
||||||
|
import torch.utils.checkpoint as ckpt
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Flux2Params:
|
||||||
|
in_channels: int = 128
|
||||||
|
context_in_dim: int = 15360
|
||||||
|
hidden_size: int = 6144
|
||||||
|
num_heads: int = 48
|
||||||
|
depth: int = 8
|
||||||
|
depth_single_blocks: int = 48
|
||||||
|
axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])
|
||||||
|
theta: int = 2000
|
||||||
|
mlp_ratio: float = 3.0
|
||||||
|
|
||||||
|
|
||||||
|
class FakeConfig:
|
||||||
|
# for diffusers compatability
|
||||||
|
def __init__(self):
|
||||||
|
self.patch_size = 1
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2(nn.Module):
|
||||||
|
def __init__(self, params: Flux2Params):
|
||||||
|
super().__init__()
|
||||||
|
self.config = FakeConfig()
|
||||||
|
|
||||||
|
self.in_channels = params.in_channels
|
||||||
|
self.out_channels = params.in_channels
|
||||||
|
if params.hidden_size % params.num_heads != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||||
|
)
|
||||||
|
pe_dim = params.hidden_size // params.num_heads
|
||||||
|
if sum(params.axes_dim) != pe_dim:
|
||||||
|
raise ValueError(
|
||||||
|
f"Got {params.axes_dim} but expected positional dim {pe_dim}"
|
||||||
|
)
|
||||||
|
self.hidden_size = params.hidden_size
|
||||||
|
self.num_heads = params.num_heads
|
||||||
|
self.pe_embedder = EmbedND(
|
||||||
|
dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
|
||||||
|
)
|
||||||
|
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=False)
|
||||||
|
self.time_in = MLPEmbedder(
|
||||||
|
in_dim=256, hidden_dim=self.hidden_size, disable_bias=True
|
||||||
|
)
|
||||||
|
self.guidance_in = MLPEmbedder(
|
||||||
|
in_dim=256, hidden_dim=self.hidden_size, disable_bias=True
|
||||||
|
)
|
||||||
|
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, bias=False)
|
||||||
|
|
||||||
|
self.double_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
DoubleStreamBlock(
|
||||||
|
self.hidden_size,
|
||||||
|
self.num_heads,
|
||||||
|
mlp_ratio=params.mlp_ratio,
|
||||||
|
)
|
||||||
|
for _ in range(params.depth)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.single_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
SingleStreamBlock(
|
||||||
|
self.hidden_size,
|
||||||
|
self.num_heads,
|
||||||
|
mlp_ratio=params.mlp_ratio,
|
||||||
|
)
|
||||||
|
for _ in range(params.depth_single_blocks)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.double_stream_modulation_img = Modulation(
|
||||||
|
self.hidden_size,
|
||||||
|
double=True,
|
||||||
|
disable_bias=True,
|
||||||
|
)
|
||||||
|
self.double_stream_modulation_txt = Modulation(
|
||||||
|
self.hidden_size,
|
||||||
|
double=True,
|
||||||
|
disable_bias=True,
|
||||||
|
)
|
||||||
|
self.single_stream_modulation = Modulation(
|
||||||
|
self.hidden_size, double=False, disable_bias=True
|
||||||
|
)
|
||||||
|
|
||||||
|
self.final_layer = LastLayer(
|
||||||
|
self.hidden_size,
|
||||||
|
self.out_channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return next(self.parameters()).device
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
return next(self.parameters()).dtype
|
||||||
|
|
||||||
|
def enable_gradient_checkpointing(self):
|
||||||
|
self.gradient_checkpointing = True
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
x_ids: Tensor,
|
||||||
|
timesteps: Tensor,
|
||||||
|
ctx: Tensor,
|
||||||
|
ctx_ids: Tensor,
|
||||||
|
guidance: Tensor,
|
||||||
|
):
|
||||||
|
num_txt_tokens = ctx.shape[1]
|
||||||
|
|
||||||
|
timestep_emb = timestep_embedding(timesteps, 256)
|
||||||
|
vec = self.time_in(timestep_emb)
|
||||||
|
guidance_emb = timestep_embedding(guidance, 256)
|
||||||
|
vec = vec + self.guidance_in(guidance_emb)
|
||||||
|
|
||||||
|
double_block_mod_img = self.double_stream_modulation_img(vec)
|
||||||
|
double_block_mod_txt = self.double_stream_modulation_txt(vec)
|
||||||
|
single_block_mod, _ = self.single_stream_modulation(vec)
|
||||||
|
|
||||||
|
img = self.img_in(x)
|
||||||
|
txt = self.txt_in(ctx)
|
||||||
|
|
||||||
|
pe_x = self.pe_embedder(x_ids)
|
||||||
|
pe_ctx = self.pe_embedder(ctx_ids)
|
||||||
|
|
||||||
|
for block in self.double_blocks:
|
||||||
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||||
|
img.requires_grad_(True)
|
||||||
|
txt.requires_grad_(True)
|
||||||
|
img, txt = ckpt.checkpoint(
|
||||||
|
block,
|
||||||
|
img,
|
||||||
|
txt,
|
||||||
|
pe_x,
|
||||||
|
pe_ctx,
|
||||||
|
double_block_mod_img,
|
||||||
|
double_block_mod_txt,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
img, txt = block(
|
||||||
|
img,
|
||||||
|
txt,
|
||||||
|
pe_x,
|
||||||
|
pe_ctx,
|
||||||
|
double_block_mod_img,
|
||||||
|
double_block_mod_txt,
|
||||||
|
)
|
||||||
|
|
||||||
|
img = torch.cat((txt, img), dim=1)
|
||||||
|
pe = torch.cat((pe_ctx, pe_x), dim=2)
|
||||||
|
|
||||||
|
for i, block in enumerate(self.single_blocks):
|
||||||
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||||
|
img.requires_grad_(True)
|
||||||
|
img = ckpt.checkpoint(
|
||||||
|
block,
|
||||||
|
img,
|
||||||
|
pe,
|
||||||
|
single_block_mod,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
img = block(
|
||||||
|
img,
|
||||||
|
pe,
|
||||||
|
single_block_mod,
|
||||||
|
)
|
||||||
|
|
||||||
|
img = img[:, num_txt_tokens:, ...]
|
||||||
|
|
||||||
|
img = self.final_layer(img, vec)
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
num_heads: int = 8,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
head_dim = dim // num_heads
|
||||||
|
self.qkv = nn.Linear(dim, dim * 3, bias=False)
|
||||||
|
|
||||||
|
self.norm = QKNorm(head_dim)
|
||||||
|
self.proj = nn.Linear(dim, dim, bias=False)
|
||||||
|
|
||||||
|
|
||||||
|
class SiLUActivation(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.gate_fn = nn.SiLU()
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
x1, x2 = x.chunk(2, dim=-1)
|
||||||
|
return self.gate_fn(x1) * x2
|
||||||
|
|
||||||
|
|
||||||
|
class Modulation(nn.Module):
|
||||||
|
def __init__(self, dim: int, double: bool, disable_bias: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
self.is_double = double
|
||||||
|
self.multiplier = 6 if double else 3
|
||||||
|
self.lin = nn.Linear(dim, self.multiplier * dim, bias=not disable_bias)
|
||||||
|
|
||||||
|
def forward(self, vec: torch.Tensor):
|
||||||
|
out = self.lin(nn.functional.silu(vec))
|
||||||
|
if out.ndim == 2:
|
||||||
|
out = out[:, None, :]
|
||||||
|
out = out.chunk(self.multiplier, dim=-1)
|
||||||
|
return out[:3], out[3:] if self.is_double else None
|
||||||
|
|
||||||
|
|
||||||
|
class LastLayer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
out_channels: int,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.linear = nn.Linear(hidden_size, out_channels, bias=False)
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
|
||||||
|
mod = self.adaLN_modulation(vec)
|
||||||
|
shift, scale = mod.chunk(2, dim=-1)
|
||||||
|
if shift.ndim == 2:
|
||||||
|
shift = shift[:, None, :]
|
||||||
|
scale = scale[:, None, :]
|
||||||
|
x = (1 + scale) * self.norm_final(x) + shift
|
||||||
|
x = self.linear(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SingleStreamBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.hidden_dim = hidden_size
|
||||||
|
self.num_heads = num_heads
|
||||||
|
head_dim = hidden_size // num_heads
|
||||||
|
self.scale = head_dim**-0.5
|
||||||
|
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||||
|
self.mlp_mult_factor = 2
|
||||||
|
|
||||||
|
self.linear1 = nn.Linear(
|
||||||
|
hidden_size,
|
||||||
|
hidden_size * 3 + self.mlp_hidden_dim * self.mlp_mult_factor,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.linear2 = nn.Linear(
|
||||||
|
hidden_size + self.mlp_hidden_dim, hidden_size, bias=False
|
||||||
|
)
|
||||||
|
|
||||||
|
self.norm = QKNorm(head_dim)
|
||||||
|
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
|
||||||
|
self.mlp_act = SiLUActivation()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
pe: Tensor,
|
||||||
|
mod: tuple[Tensor, Tensor],
|
||||||
|
) -> Tensor:
|
||||||
|
mod_shift, mod_scale, mod_gate = mod
|
||||||
|
x_mod = (1 + mod_scale) * self.pre_norm(x) + mod_shift
|
||||||
|
|
||||||
|
qkv, mlp = torch.split(
|
||||||
|
self.linear1(x_mod),
|
||||||
|
[3 * self.hidden_size, self.mlp_hidden_dim * self.mlp_mult_factor],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||||
|
q, k = self.norm(q, k, v)
|
||||||
|
|
||||||
|
attn = attention(q, k, v, pe)
|
||||||
|
|
||||||
|
# compute activation in mlp stream, cat again and run second linear layer
|
||||||
|
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||||
|
return x + mod_gate * output
|
||||||
|
|
||||||
|
|
||||||
|
class DoubleStreamBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
mlp_ratio: float,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||||
|
self.num_heads = num_heads
|
||||||
|
assert hidden_size % num_heads == 0, (
|
||||||
|
f"{hidden_size=} must be divisible by {num_heads=}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.mlp_mult_factor = 2
|
||||||
|
|
||||||
|
self.img_attn = SelfAttention(
|
||||||
|
dim=hidden_size,
|
||||||
|
num_heads=num_heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.img_mlp = nn.Sequential(
|
||||||
|
nn.Linear(hidden_size, mlp_hidden_dim * self.mlp_mult_factor, bias=False),
|
||||||
|
SiLUActivation(),
|
||||||
|
nn.Linear(mlp_hidden_dim, hidden_size, bias=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.txt_attn = SelfAttention(
|
||||||
|
dim=hidden_size,
|
||||||
|
num_heads=num_heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.txt_mlp = nn.Sequential(
|
||||||
|
nn.Linear(
|
||||||
|
hidden_size,
|
||||||
|
mlp_hidden_dim * self.mlp_mult_factor,
|
||||||
|
bias=False,
|
||||||
|
),
|
||||||
|
SiLUActivation(),
|
||||||
|
nn.Linear(mlp_hidden_dim, hidden_size, bias=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
img: Tensor,
|
||||||
|
txt: Tensor,
|
||||||
|
pe: Tensor,
|
||||||
|
pe_ctx: Tensor,
|
||||||
|
mod_img: tuple[Tensor, Tensor],
|
||||||
|
mod_txt: tuple[Tensor, Tensor],
|
||||||
|
) -> tuple[Tensor, Tensor]:
|
||||||
|
img_mod1, img_mod2 = mod_img
|
||||||
|
txt_mod1, txt_mod2 = mod_txt
|
||||||
|
|
||||||
|
img_mod1_shift, img_mod1_scale, img_mod1_gate = img_mod1
|
||||||
|
img_mod2_shift, img_mod2_scale, img_mod2_gate = img_mod2
|
||||||
|
txt_mod1_shift, txt_mod1_scale, txt_mod1_gate = txt_mod1
|
||||||
|
txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = txt_mod2
|
||||||
|
|
||||||
|
# prepare image for attention
|
||||||
|
img_modulated = self.img_norm1(img)
|
||||||
|
img_modulated = (1 + img_mod1_scale) * img_modulated + img_mod1_shift
|
||||||
|
|
||||||
|
img_qkv = self.img_attn.qkv(img_modulated)
|
||||||
|
img_q, img_k, img_v = rearrange(
|
||||||
|
img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
|
||||||
|
)
|
||||||
|
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||||
|
|
||||||
|
# prepare txt for attention
|
||||||
|
txt_modulated = self.txt_norm1(txt)
|
||||||
|
txt_modulated = (1 + txt_mod1_scale) * txt_modulated + txt_mod1_shift
|
||||||
|
|
||||||
|
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||||
|
txt_q, txt_k, txt_v = rearrange(
|
||||||
|
txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
|
||||||
|
)
|
||||||
|
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||||
|
|
||||||
|
q = torch.cat((txt_q, img_q), dim=2)
|
||||||
|
k = torch.cat((txt_k, img_k), dim=2)
|
||||||
|
v = torch.cat((txt_v, img_v), dim=2)
|
||||||
|
|
||||||
|
pe = torch.cat((pe_ctx, pe), dim=2)
|
||||||
|
attn = attention(q, k, v, pe)
|
||||||
|
txt_attn, img_attn = attn[:, : txt_q.shape[2]], attn[:, txt_q.shape[2] :]
|
||||||
|
|
||||||
|
# calculate the img blocks
|
||||||
|
img = img + img_mod1_gate * self.img_attn.proj(img_attn)
|
||||||
|
img = img + img_mod2_gate * self.img_mlp(
|
||||||
|
(1 + img_mod2_scale) * (self.img_norm2(img)) + img_mod2_shift
|
||||||
|
)
|
||||||
|
|
||||||
|
# calculate the txt blocks
|
||||||
|
txt = txt + txt_mod1_gate * self.txt_attn.proj(txt_attn)
|
||||||
|
txt = txt + txt_mod2_gate * self.txt_mlp(
|
||||||
|
(1 + txt_mod2_scale) * (self.txt_norm2(txt)) + txt_mod2_shift
|
||||||
|
)
|
||||||
|
return img, txt
|
||||||
|
|
||||||
|
|
||||||
|
class MLPEmbedder(nn.Module):
|
||||||
|
def __init__(self, in_dim: int, hidden_dim: int, disable_bias: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=not disable_bias)
|
||||||
|
self.silu = nn.SiLU()
|
||||||
|
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=not disable_bias)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return self.out_layer(self.silu(self.in_layer(x)))
|
||||||
|
|
||||||
|
|
||||||
|
class EmbedND(nn.Module):
|
||||||
|
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.theta = theta
|
||||||
|
self.axes_dim = axes_dim
|
||||||
|
|
||||||
|
def forward(self, ids: Tensor) -> Tensor:
|
||||||
|
emb = torch.cat(
|
||||||
|
[
|
||||||
|
rope(ids[..., i], self.axes_dim[i], self.theta)
|
||||||
|
for i in range(len(self.axes_dim))
|
||||||
|
],
|
||||||
|
dim=-3,
|
||||||
|
)
|
||||||
|
|
||||||
|
return emb.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
|
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
||||||
|
"""
|
||||||
|
Create sinusoidal timestep embeddings.
|
||||||
|
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||||
|
These may be fractional.
|
||||||
|
:param dim: the dimension of the output.
|
||||||
|
:param max_period: controls the minimum frequency of the embeddings.
|
||||||
|
:return: an (N, D) Tensor of positional embeddings.
|
||||||
|
"""
|
||||||
|
t = time_factor * t
|
||||||
|
half = dim // 2
|
||||||
|
freqs = torch.exp(
|
||||||
|
-math.log(max_period)
|
||||||
|
* torch.arange(start=0, end=half, device=t.device, dtype=torch.float32)
|
||||||
|
/ half
|
||||||
|
)
|
||||||
|
|
||||||
|
args = t[:, None].float() * freqs[None]
|
||||||
|
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||||
|
if dim % 2:
|
||||||
|
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||||
|
if torch.is_floating_point(t):
|
||||||
|
embedding = embedding.to(t)
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(torch.nn.Module):
|
||||||
|
def __init__(self, dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = nn.Parameter(torch.ones(dim))
|
||||||
|
|
||||||
|
def forward(self, x: Tensor):
|
||||||
|
x_dtype = x.dtype
|
||||||
|
x = x.float()
|
||||||
|
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
||||||
|
return (x * rrms).to(dtype=x_dtype) * self.scale
|
||||||
|
|
||||||
|
|
||||||
|
class QKNorm(torch.nn.Module):
|
||||||
|
def __init__(self, dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.query_norm = RMSNorm(dim)
|
||||||
|
self.key_norm = RMSNorm(dim)
|
||||||
|
|
||||||
|
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
||||||
|
q = self.query_norm(q)
|
||||||
|
k = self.key_norm(k)
|
||||||
|
return q.to(v), k.to(v)
|
||||||
|
|
||||||
|
|
||||||
|
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
||||||
|
q, k = apply_rope(q, k, pe)
|
||||||
|
|
||||||
|
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||||
|
x = rearrange(x, "B H L D -> B L (H D)")
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||||
|
assert dim % 2 == 0
|
||||||
|
scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
|
||||||
|
omega = 1.0 / (theta**scale)
|
||||||
|
out = torch.einsum("...n,d->...nd", pos, omega)
|
||||||
|
out = torch.stack(
|
||||||
|
[torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
|
||||||
|
)
|
||||||
|
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||||
|
return out.float()
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
||||||
|
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||||
|
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||||
|
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||||
|
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||||
|
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||||
358
extensions_built_in/diffusion_models/flux2/src/pipeline.py
Normal file
358
extensions_built_in/diffusion_models/flux2/src/pipeline.py
Normal file
@@ -0,0 +1,358 @@
|
|||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import PIL.Image
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
from diffusers.image_processor import VaeImageProcessor
|
||||||
|
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||||
|
from diffusers.utils import (
|
||||||
|
logging,
|
||||||
|
)
|
||||||
|
from diffusers.utils.torch_utils import randn_tensor
|
||||||
|
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||||
|
|
||||||
|
from diffusers.utils import BaseOutput
|
||||||
|
|
||||||
|
from .autoencoder import AutoEncoder
|
||||||
|
from .model import Flux2
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
|
||||||
|
|
||||||
|
from .sampling import (
|
||||||
|
get_schedule,
|
||||||
|
batched_prc_img,
|
||||||
|
batched_prc_txt,
|
||||||
|
encode_image_refs,
|
||||||
|
scatter_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Flux2ImagePipelineOutput(BaseOutput):
|
||||||
|
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
SYSTEM_MESSAGE = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object
|
||||||
|
attribution and actions without speculation."""
|
||||||
|
OUTPUT_LAYERS = [10, 20, 30]
|
||||||
|
MAX_LENGTH = 512
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2Pipeline(DiffusionPipeline):
|
||||||
|
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||||
|
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||||
|
vae: AutoEncoder,
|
||||||
|
text_encoder: Mistral3ForConditionalGeneration,
|
||||||
|
tokenizer: AutoProcessor,
|
||||||
|
transformer: Flux2,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.register_modules(
|
||||||
|
vae=vae,
|
||||||
|
text_encoder=text_encoder,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
transformer=transformer,
|
||||||
|
scheduler=scheduler,
|
||||||
|
)
|
||||||
|
self.vae_scale_factor = 16 # 8x plus 2x pixel shuffle
|
||||||
|
self.num_channels_latents = 128
|
||||||
|
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||||
|
self.default_sample_size = 64
|
||||||
|
|
||||||
|
def format_input(
|
||||||
|
self,
|
||||||
|
txt: list[str],
|
||||||
|
) -> list[list[dict]]:
|
||||||
|
# Remove [IMG] tokens from prompts to avoid Pixtral validation issues
|
||||||
|
# when truncation is enabled. The processor counts [IMG] tokens and fails
|
||||||
|
# if the count changes after truncation.
|
||||||
|
cleaned_txt = [prompt.replace("[IMG]", "") for prompt in txt]
|
||||||
|
|
||||||
|
return [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": [{"type": "text", "text": SYSTEM_MESSAGE}],
|
||||||
|
},
|
||||||
|
{"role": "user", "content": [{"type": "text", "text": prompt}]},
|
||||||
|
]
|
||||||
|
for prompt in cleaned_txt
|
||||||
|
]
|
||||||
|
|
||||||
|
def _get_mistral_prompt_embeds(
|
||||||
|
self,
|
||||||
|
prompt: Union[str, List[str]] = None,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
max_sequence_length: int = 512,
|
||||||
|
):
|
||||||
|
device = device or self._execution_device
|
||||||
|
dtype = dtype or self.text_encoder.dtype
|
||||||
|
|
||||||
|
if not isinstance(prompt, list):
|
||||||
|
prompt = [prompt]
|
||||||
|
|
||||||
|
# Format input messages
|
||||||
|
messages_batch = self.format_input(txt=prompt)
|
||||||
|
|
||||||
|
# Process all messages at once
|
||||||
|
# with image processing a too short max length can throw an error in here.
|
||||||
|
try:
|
||||||
|
inputs = self.tokenizer.apply_chat_template(
|
||||||
|
messages_batch,
|
||||||
|
add_generation_prompt=False,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding="max_length",
|
||||||
|
truncation=True,
|
||||||
|
max_length=max_sequence_length,
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
print(
|
||||||
|
f"Error processing input: {e}, your max length is probably too short, when you have images in the input."
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
# Move to device
|
||||||
|
input_ids = inputs["input_ids"].to(device)
|
||||||
|
attention_mask = inputs["attention_mask"].to(device)
|
||||||
|
|
||||||
|
# Forward pass through the model
|
||||||
|
output = self.text_encoder(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
output_hidden_states=True,
|
||||||
|
use_cache=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS], dim=1)
|
||||||
|
prompt_embeds = rearrange(out, "b c l d -> b l (c d)")
|
||||||
|
|
||||||
|
# they don't return attention mask, so we create it here
|
||||||
|
return prompt_embeds, None
|
||||||
|
|
||||||
|
def encode_prompt(
|
||||||
|
self,
|
||||||
|
prompt: Union[str, List[str]],
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
num_images_per_prompt: int = 1,
|
||||||
|
prompt_embeds: Optional[torch.Tensor] = None,
|
||||||
|
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||||
|
max_sequence_length: int = 512,
|
||||||
|
):
|
||||||
|
device = device or self._execution_device
|
||||||
|
|
||||||
|
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||||
|
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
|
||||||
|
|
||||||
|
if prompt_embeds is None:
|
||||||
|
prompt_embeds, prompt_embeds_mask = self._get_mistral_prompt_embeds(
|
||||||
|
prompt, device, max_sequence_length=max_sequence_length
|
||||||
|
)
|
||||||
|
|
||||||
|
_, seq_len, _ = prompt_embeds.shape
|
||||||
|
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||||
|
prompt_embeds = prompt_embeds.view(
|
||||||
|
batch_size * num_images_per_prompt, seq_len, -1
|
||||||
|
)
|
||||||
|
|
||||||
|
return prompt_embeds, prompt_embeds_mask
|
||||||
|
|
||||||
|
def prepare_latents(
|
||||||
|
self,
|
||||||
|
batch_size,
|
||||||
|
num_channels_latents,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
generator,
|
||||||
|
latents=None,
|
||||||
|
):
|
||||||
|
height = int(height) // self.vae_scale_factor
|
||||||
|
width = int(width) // self.vae_scale_factor
|
||||||
|
|
||||||
|
shape = (batch_size, num_channels_latents, height, width)
|
||||||
|
|
||||||
|
if latents is not None:
|
||||||
|
return latents.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
if isinstance(generator, list) and len(generator) != batch_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||||
|
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||||
|
)
|
||||||
|
|
||||||
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
return latents
|
||||||
|
|
||||||
|
@property
|
||||||
|
def guidance_scale(self):
|
||||||
|
return self._guidance_scale
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_timesteps(self):
|
||||||
|
return self._num_timesteps
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_timestep(self):
|
||||||
|
return self._current_timestep
|
||||||
|
|
||||||
|
@property
|
||||||
|
def interrupt(self):
|
||||||
|
return self._interrupt
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
prompt: Union[str, List[str]] = None,
|
||||||
|
height: Optional[int] = None,
|
||||||
|
width: Optional[int] = None,
|
||||||
|
num_inference_steps: int = 50,
|
||||||
|
guidance_scale: Optional[float] = None,
|
||||||
|
num_images_per_prompt: int = 1,
|
||||||
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||||
|
latents: Optional[torch.Tensor] = None,
|
||||||
|
prompt_embeds: Optional[torch.Tensor] = None,
|
||||||
|
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||||
|
output_type: Optional[str] = "pil",
|
||||||
|
return_dict: bool = True,
|
||||||
|
max_sequence_length: int = 512,
|
||||||
|
control_img_list: Optional[List[PIL.Image.Image]] = None,
|
||||||
|
):
|
||||||
|
height = height or self.default_sample_size * self.vae_scale_factor
|
||||||
|
width = width or self.default_sample_size * self.vae_scale_factor
|
||||||
|
|
||||||
|
self._guidance_scale = guidance_scale
|
||||||
|
self._current_timestep = None
|
||||||
|
self._interrupt = False
|
||||||
|
|
||||||
|
# 2. Define call parameters
|
||||||
|
if prompt is not None and isinstance(prompt, str):
|
||||||
|
batch_size = 1
|
||||||
|
elif prompt is not None and isinstance(prompt, list):
|
||||||
|
batch_size = len(prompt)
|
||||||
|
else:
|
||||||
|
batch_size = prompt_embeds.shape[0]
|
||||||
|
|
||||||
|
device = self._execution_device
|
||||||
|
|
||||||
|
# 3. Encode the prompt
|
||||||
|
|
||||||
|
prompt_embeds, _ = self.encode_prompt(
|
||||||
|
prompt=prompt,
|
||||||
|
prompt_embeds=prompt_embeds,
|
||||||
|
prompt_embeds_mask=prompt_embeds_mask,
|
||||||
|
device=device,
|
||||||
|
num_images_per_prompt=num_images_per_prompt,
|
||||||
|
max_sequence_length=max_sequence_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
txt, txt_ids = batched_prc_txt(prompt_embeds)
|
||||||
|
|
||||||
|
# 4. Prepare latent variables\
|
||||||
|
latents = self.prepare_latents(
|
||||||
|
batch_size * num_images_per_prompt,
|
||||||
|
self.num_channels_latents,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
prompt_embeds.dtype,
|
||||||
|
device,
|
||||||
|
generator,
|
||||||
|
latents,
|
||||||
|
)
|
||||||
|
|
||||||
|
packed_latents, img_ids = batched_prc_img(latents)
|
||||||
|
|
||||||
|
timesteps = get_schedule(num_inference_steps, packed_latents.shape[1])
|
||||||
|
|
||||||
|
self._num_timesteps = len(timesteps)
|
||||||
|
|
||||||
|
guidance_vec = torch.full(
|
||||||
|
(packed_latents.shape[0],),
|
||||||
|
guidance_scale,
|
||||||
|
device=packed_latents.device,
|
||||||
|
dtype=packed_latents.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
if control_img_list is not None and len(control_img_list) > 0:
|
||||||
|
img_cond_seq, img_cond_seq_ids = encode_image_refs(
|
||||||
|
self.vae, control_img_list
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
img_cond_seq, img_cond_seq_ids = None, None
|
||||||
|
|
||||||
|
# 6. Denoising loop
|
||||||
|
i = 0
|
||||||
|
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||||
|
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
|
||||||
|
if self.interrupt:
|
||||||
|
continue
|
||||||
|
t_vec = torch.full(
|
||||||
|
(packed_latents.shape[0],),
|
||||||
|
t_curr,
|
||||||
|
dtype=packed_latents.dtype,
|
||||||
|
device=packed_latents.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._current_timestep = t_curr
|
||||||
|
img_input = packed_latents
|
||||||
|
img_input_ids = img_ids
|
||||||
|
|
||||||
|
if img_cond_seq is not None:
|
||||||
|
assert img_cond_seq_ids is not None, (
|
||||||
|
"You need to provide either both or neither of the sequence conditioning"
|
||||||
|
)
|
||||||
|
img_input = torch.cat((img_input, img_cond_seq), dim=1)
|
||||||
|
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
|
||||||
|
|
||||||
|
pred = self.transformer(
|
||||||
|
x=img_input,
|
||||||
|
x_ids=img_input_ids,
|
||||||
|
timesteps=t_vec,
|
||||||
|
ctx=txt,
|
||||||
|
ctx_ids=txt_ids,
|
||||||
|
guidance=guidance_vec,
|
||||||
|
)
|
||||||
|
|
||||||
|
if img_cond_seq is not None:
|
||||||
|
pred = pred[:, : packed_latents.shape[1]]
|
||||||
|
|
||||||
|
packed_latents = packed_latents + (t_prev - t_curr) * pred
|
||||||
|
i += 1
|
||||||
|
progress_bar.update(1)
|
||||||
|
|
||||||
|
self._current_timestep = None
|
||||||
|
|
||||||
|
# 7. Post-processing
|
||||||
|
latents = torch.cat(scatter_ids(packed_latents, img_ids)).squeeze(2)
|
||||||
|
|
||||||
|
if output_type == "latent":
|
||||||
|
image = latents
|
||||||
|
else:
|
||||||
|
latents = latents.to(self.vae.dtype)
|
||||||
|
image = self.vae.decode(latents).float()
|
||||||
|
|
||||||
|
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||||
|
|
||||||
|
# Offload all models
|
||||||
|
self.maybe_free_model_hooks()
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (image,)
|
||||||
|
|
||||||
|
return Flux2ImagePipelineOutput(images=image)
|
||||||
365
extensions_built_in/diffusion_models/flux2/src/sampling.py
Normal file
365
extensions_built_in/diffusion_models/flux2/src/sampling.py
Normal file
@@ -0,0 +1,365 @@
|
|||||||
|
import math
|
||||||
|
from typing import Callable, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
from PIL import Image
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from .model import Flux2
|
||||||
|
import torchvision
|
||||||
|
|
||||||
|
|
||||||
|
def compress_time(t_ids: Tensor) -> Tensor:
|
||||||
|
assert t_ids.ndim == 1
|
||||||
|
t_ids_max = torch.max(t_ids)
|
||||||
|
t_remap = torch.zeros((t_ids_max + 1,), device=t_ids.device, dtype=t_ids.dtype)
|
||||||
|
t_unique_sorted_ids = torch.unique(t_ids, sorted=True)
|
||||||
|
t_remap[t_unique_sorted_ids] = torch.arange(
|
||||||
|
len(t_unique_sorted_ids), device=t_ids.device, dtype=t_ids.dtype
|
||||||
|
)
|
||||||
|
t_ids_compressed = t_remap[t_ids]
|
||||||
|
return t_ids_compressed
|
||||||
|
|
||||||
|
|
||||||
|
def scatter_ids(x: Tensor, x_ids: Tensor) -> list[Tensor]:
|
||||||
|
"""
|
||||||
|
using position ids to scatter tokens into place
|
||||||
|
"""
|
||||||
|
x_list = []
|
||||||
|
t_coords = []
|
||||||
|
for data, pos in zip(x, x_ids):
|
||||||
|
_, ch = data.shape # noqa: F841
|
||||||
|
t_ids = pos[:, 0].to(torch.int64)
|
||||||
|
h_ids = pos[:, 1].to(torch.int64)
|
||||||
|
w_ids = pos[:, 2].to(torch.int64)
|
||||||
|
|
||||||
|
t_ids_cmpr = compress_time(t_ids)
|
||||||
|
|
||||||
|
t = torch.max(t_ids_cmpr) + 1
|
||||||
|
h = torch.max(h_ids) + 1
|
||||||
|
w = torch.max(w_ids) + 1
|
||||||
|
|
||||||
|
flat_ids = t_ids_cmpr * w * h + h_ids * w + w_ids
|
||||||
|
|
||||||
|
out = torch.zeros((t * h * w, ch), device=data.device, dtype=data.dtype)
|
||||||
|
out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
|
||||||
|
|
||||||
|
x_list.append(rearrange(out, "(t h w) c -> 1 c t h w", t=t, h=h, w=w))
|
||||||
|
t_coords.append(torch.unique(t_ids, sorted=True))
|
||||||
|
return x_list
|
||||||
|
|
||||||
|
|
||||||
|
def encode_image_refs(
|
||||||
|
ae,
|
||||||
|
img_ctx: Union[list[Image.Image], list[torch.Tensor]],
|
||||||
|
scale=10,
|
||||||
|
limit_pixels=1024**2,
|
||||||
|
):
|
||||||
|
if not img_ctx:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
img_ctx_prep = default_prep(img=img_ctx, limit_pixels=limit_pixels)
|
||||||
|
if not isinstance(img_ctx_prep, list):
|
||||||
|
img_ctx_prep = [img_ctx_prep]
|
||||||
|
|
||||||
|
# Encode each reference image
|
||||||
|
encoded_refs = []
|
||||||
|
for img in img_ctx_prep:
|
||||||
|
if img.ndim == 3:
|
||||||
|
img = img.unsqueeze(0)
|
||||||
|
encoded = ae.encode(img.to(ae.device, ae.dtype))[0]
|
||||||
|
encoded_refs.append(encoded)
|
||||||
|
|
||||||
|
# Create time offsets for each reference
|
||||||
|
t_off = [scale + scale * t for t in torch.arange(0, len(encoded_refs))]
|
||||||
|
t_off = [t.view(-1) for t in t_off]
|
||||||
|
|
||||||
|
# Process with position IDs
|
||||||
|
ref_tokens, ref_ids = listed_prc_img(encoded_refs, t_coord=t_off)
|
||||||
|
|
||||||
|
# Concatenate all references along sequence dimension
|
||||||
|
ref_tokens = torch.cat(ref_tokens, dim=0) # (total_ref_tokens, C)
|
||||||
|
ref_ids = torch.cat(ref_ids, dim=0) # (total_ref_tokens, 4)
|
||||||
|
|
||||||
|
# Add batch dimension
|
||||||
|
ref_tokens = ref_tokens.unsqueeze(0) # (1, total_ref_tokens, C)
|
||||||
|
ref_ids = ref_ids.unsqueeze(0) # (1, total_ref_tokens, 4)
|
||||||
|
|
||||||
|
return ref_tokens.to(torch.bfloat16), ref_ids
|
||||||
|
|
||||||
|
|
||||||
|
def prc_txt(
|
||||||
|
x: Tensor, t_coord: Tensor | None = None, l_coord: Tensor | None = None
|
||||||
|
) -> tuple[Tensor, Tensor]:
|
||||||
|
assert l_coord is None, "l_coord not supported for txts"
|
||||||
|
|
||||||
|
_l, _ = x.shape # noqa: F841
|
||||||
|
|
||||||
|
coords = {
|
||||||
|
"t": torch.arange(1) if t_coord is None else t_coord,
|
||||||
|
"h": torch.arange(1), # dummy dimension
|
||||||
|
"w": torch.arange(1), # dummy dimension
|
||||||
|
"l": torch.arange(_l),
|
||||||
|
}
|
||||||
|
x_ids = torch.cartesian_prod(coords["t"], coords["h"], coords["w"], coords["l"])
|
||||||
|
return x, x_ids.to(x.device)
|
||||||
|
|
||||||
|
|
||||||
|
def batched_wrapper(fn):
|
||||||
|
def batched_prc(
|
||||||
|
x: Tensor, t_coord: Tensor | None = None, l_coord: Tensor | None = None
|
||||||
|
) -> tuple[Tensor, Tensor]:
|
||||||
|
results = []
|
||||||
|
for i in range(len(x)):
|
||||||
|
results.append(
|
||||||
|
fn(
|
||||||
|
x[i],
|
||||||
|
t_coord[i] if t_coord is not None else None,
|
||||||
|
l_coord[i] if l_coord is not None else None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
x, x_ids = zip(*results)
|
||||||
|
return torch.stack(x), torch.stack(x_ids)
|
||||||
|
|
||||||
|
return batched_prc
|
||||||
|
|
||||||
|
|
||||||
|
def listed_wrapper(fn):
|
||||||
|
def listed_prc(
|
||||||
|
x: list[Tensor],
|
||||||
|
t_coord: list[Tensor] | None = None,
|
||||||
|
l_coord: list[Tensor] | None = None,
|
||||||
|
) -> tuple[list[Tensor], list[Tensor]]:
|
||||||
|
results = []
|
||||||
|
for i in range(len(x)):
|
||||||
|
results.append(
|
||||||
|
fn(
|
||||||
|
x[i],
|
||||||
|
t_coord[i] if t_coord is not None else None,
|
||||||
|
l_coord[i] if l_coord is not None else None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
x, x_ids = zip(*results)
|
||||||
|
return list(x), list(x_ids)
|
||||||
|
|
||||||
|
return listed_prc
|
||||||
|
|
||||||
|
|
||||||
|
def prc_img(
|
||||||
|
x: Tensor, t_coord: Tensor | None = None, l_coord: Tensor | None = None
|
||||||
|
) -> tuple[Tensor, Tensor]:
|
||||||
|
c, h, w = x.shape # noqa: F841
|
||||||
|
x_coords = {
|
||||||
|
"t": torch.arange(1) if t_coord is None else t_coord,
|
||||||
|
"h": torch.arange(h),
|
||||||
|
"w": torch.arange(w),
|
||||||
|
"l": torch.arange(1) if l_coord is None else l_coord,
|
||||||
|
}
|
||||||
|
x_ids = torch.cartesian_prod(
|
||||||
|
x_coords["t"], x_coords["h"], x_coords["w"], x_coords["l"]
|
||||||
|
)
|
||||||
|
x = rearrange(x, "c h w -> (h w) c")
|
||||||
|
return x, x_ids.to(x.device)
|
||||||
|
|
||||||
|
|
||||||
|
listed_prc_img = listed_wrapper(prc_img)
|
||||||
|
batched_prc_img = batched_wrapper(prc_img)
|
||||||
|
batched_prc_txt = batched_wrapper(prc_txt)
|
||||||
|
|
||||||
|
|
||||||
|
def center_crop_to_multiple_of_x(
|
||||||
|
img: Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor], x: int
|
||||||
|
) -> Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor]:
|
||||||
|
if isinstance(img, list):
|
||||||
|
return [center_crop_to_multiple_of_x(_img, x) for _img in img] # type: ignore
|
||||||
|
|
||||||
|
if isinstance(img, torch.Tensor):
|
||||||
|
h, w = img.shape[-2], img.shape[-1]
|
||||||
|
else:
|
||||||
|
w, h = img.size
|
||||||
|
new_w = (w // x) * x
|
||||||
|
new_h = (h // x) * x
|
||||||
|
|
||||||
|
left = (w - new_w) // 2
|
||||||
|
top = (h - new_h) // 2
|
||||||
|
right = left + new_w
|
||||||
|
bottom = top + new_h
|
||||||
|
|
||||||
|
if isinstance(img, torch.Tensor):
|
||||||
|
return img[..., top:bottom, left:right]
|
||||||
|
resized = img.crop((left, top, right, bottom))
|
||||||
|
return resized
|
||||||
|
|
||||||
|
|
||||||
|
def cap_pixels(
|
||||||
|
img: Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor], k
|
||||||
|
):
|
||||||
|
if isinstance(img, list):
|
||||||
|
return [cap_pixels(_img, k) for _img in img]
|
||||||
|
if isinstance(img, torch.Tensor):
|
||||||
|
h, w = img.shape[-2], img.shape[-1]
|
||||||
|
else:
|
||||||
|
w, h = img.size
|
||||||
|
pixel_count = w * h
|
||||||
|
|
||||||
|
if pixel_count <= k:
|
||||||
|
return img
|
||||||
|
|
||||||
|
# Scaling factor to reduce total pixels below K
|
||||||
|
scale = math.sqrt(k / pixel_count)
|
||||||
|
new_w = int(w * scale)
|
||||||
|
new_h = int(h * scale)
|
||||||
|
|
||||||
|
if isinstance(img, torch.Tensor):
|
||||||
|
did_expand = False
|
||||||
|
if img.ndim == 3:
|
||||||
|
img = img.unsqueeze(0)
|
||||||
|
did_expand = True
|
||||||
|
img = torch.nn.functional.interpolate(
|
||||||
|
img,
|
||||||
|
size=(new_h, new_w),
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
if did_expand:
|
||||||
|
img = img.squeeze(0)
|
||||||
|
return img
|
||||||
|
return img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
||||||
|
|
||||||
|
|
||||||
|
def cap_min_pixels(
|
||||||
|
img: Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor],
|
||||||
|
max_ar=8,
|
||||||
|
min_sidelength=64,
|
||||||
|
):
|
||||||
|
if isinstance(img, list):
|
||||||
|
return [
|
||||||
|
cap_min_pixels(_img, max_ar=max_ar, min_sidelength=min_sidelength)
|
||||||
|
for _img in img
|
||||||
|
]
|
||||||
|
if isinstance(img, torch.Tensor):
|
||||||
|
h, w = img.shape[-2], img.shape[-1]
|
||||||
|
else:
|
||||||
|
w, h = img.size
|
||||||
|
if w < min_sidelength or h < min_sidelength:
|
||||||
|
raise ValueError(
|
||||||
|
f"Skipping due to minimal sidelength underschritten h {h} w {w}"
|
||||||
|
)
|
||||||
|
if w / h > max_ar or h / w > max_ar:
|
||||||
|
raise ValueError(f"Skipping due to maximal ar overschritten h {h} w {w}")
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def to_rgb(
|
||||||
|
img: Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor],
|
||||||
|
) -> Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor]:
|
||||||
|
if isinstance(img, list):
|
||||||
|
return [
|
||||||
|
to_rgb(
|
||||||
|
_img,
|
||||||
|
)
|
||||||
|
for _img in img
|
||||||
|
]
|
||||||
|
if isinstance(img, torch.Tensor):
|
||||||
|
return img # assume already in tensor format
|
||||||
|
return img.convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
def default_images_prep(
|
||||||
|
x: Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor],
|
||||||
|
) -> torch.Tensor | list[torch.Tensor]:
|
||||||
|
if isinstance(x, list):
|
||||||
|
return [default_images_prep(e) for e in x] # type: ignore
|
||||||
|
if isinstance(x, torch.Tensor):
|
||||||
|
return x # assume already in tensor format
|
||||||
|
x_tensor = torchvision.transforms.ToTensor()(x)
|
||||||
|
return 2 * x_tensor - 1
|
||||||
|
|
||||||
|
|
||||||
|
def default_prep(
|
||||||
|
img: Image.Image | list[Image.Image] | torch.Tensor | list[torch.Tensor],
|
||||||
|
limit_pixels: int,
|
||||||
|
ensure_multiple: int = 16,
|
||||||
|
) -> torch.Tensor | list[torch.Tensor]:
|
||||||
|
# if passing a tensor, assume it is -1 to 1 already
|
||||||
|
img_rgb = to_rgb(img)
|
||||||
|
img_min = cap_min_pixels(img_rgb) # type: ignore
|
||||||
|
img_cap = cap_pixels(img_min, limit_pixels) # type: ignore
|
||||||
|
img_crop = center_crop_to_multiple_of_x(img_cap, ensure_multiple) # type: ignore
|
||||||
|
img_tensor = default_images_prep(img_crop)
|
||||||
|
return img_tensor
|
||||||
|
|
||||||
|
|
||||||
|
def time_shift(mu: float, sigma: float, t: Tensor):
|
||||||
|
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||||
|
|
||||||
|
|
||||||
|
def get_lin_function(
|
||||||
|
x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
|
||||||
|
) -> Callable[[float], float]:
|
||||||
|
m = (y2 - y1) / (x2 - x1)
|
||||||
|
b = y1 - m * x1
|
||||||
|
return lambda x: m * x + b
|
||||||
|
|
||||||
|
|
||||||
|
def get_schedule(
|
||||||
|
num_steps: int,
|
||||||
|
image_seq_len: int,
|
||||||
|
base_shift: float = 0.5,
|
||||||
|
max_shift: float = 1.15,
|
||||||
|
shift: bool = True,
|
||||||
|
) -> list[float]:
|
||||||
|
# extra step for zero
|
||||||
|
timesteps = torch.linspace(1, 0, num_steps + 1)
|
||||||
|
|
||||||
|
# shifting the schedule to favor high timesteps for higher signal images
|
||||||
|
if shift:
|
||||||
|
# estimate mu based on linear estimation between two points
|
||||||
|
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
||||||
|
timesteps = time_shift(mu, 1.0, timesteps)
|
||||||
|
|
||||||
|
return timesteps.tolist()
|
||||||
|
|
||||||
|
|
||||||
|
def denoise(
|
||||||
|
model: Flux2,
|
||||||
|
# model input
|
||||||
|
img: Tensor,
|
||||||
|
img_ids: Tensor,
|
||||||
|
txt: Tensor,
|
||||||
|
txt_ids: Tensor,
|
||||||
|
# sampling parameters
|
||||||
|
timesteps: list[float],
|
||||||
|
guidance: float,
|
||||||
|
# extra img tokens (sequence-wise)
|
||||||
|
img_cond_seq: Tensor | None = None,
|
||||||
|
img_cond_seq_ids: Tensor | None = None,
|
||||||
|
):
|
||||||
|
guidance_vec = torch.full(
|
||||||
|
(img.shape[0],), guidance, device=img.device, dtype=img.dtype
|
||||||
|
)
|
||||||
|
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
|
||||||
|
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||||
|
img_input = img
|
||||||
|
img_input_ids = img_ids
|
||||||
|
if img_cond_seq is not None:
|
||||||
|
assert img_cond_seq_ids is not None, (
|
||||||
|
"You need to provide either both or neither of the sequence conditioning"
|
||||||
|
)
|
||||||
|
img_input = torch.cat((img_input, img_cond_seq), dim=1)
|
||||||
|
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
|
||||||
|
pred = model(
|
||||||
|
x=img_input,
|
||||||
|
x_ids=img_input_ids,
|
||||||
|
timesteps=t_vec,
|
||||||
|
ctx=txt,
|
||||||
|
ctx_ids=txt_ids,
|
||||||
|
guidance=guidance_vec,
|
||||||
|
)
|
||||||
|
if img_input_ids is not None:
|
||||||
|
pred = pred[:, : img.shape[1]]
|
||||||
|
|
||||||
|
img = img + (t_prev - t_curr) * pred
|
||||||
|
|
||||||
|
return img
|
||||||
@@ -258,7 +258,13 @@ export const modelArchs: ModelArch[] = [
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
disableSections: ['network.conv'],
|
disableSections: ['network.conv'],
|
||||||
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'model.multistage', 'model.layer_offloading'],
|
additionalSections: [
|
||||||
|
'sample.ctrl_img',
|
||||||
|
'datasets.num_frames',
|
||||||
|
'model.low_vram',
|
||||||
|
'model.multistage',
|
||||||
|
'model.layer_offloading',
|
||||||
|
],
|
||||||
accuracyRecoveryAdapters: {
|
accuracyRecoveryAdapters: {
|
||||||
'4 bit with ARA': 'uint4|ostris/accuracy_recovery_adapters/wan22_14b_i2v_torchao_uint4.safetensors',
|
'4 bit with ARA': 'uint4|ostris/accuracy_recovery_adapters/wan22_14b_i2v_torchao_uint4.safetensors',
|
||||||
},
|
},
|
||||||
@@ -459,6 +465,37 @@ export const modelArchs: ModelArch[] = [
|
|||||||
disableSections: ['network.conv'],
|
disableSections: ['network.conv'],
|
||||||
additionalSections: ['datasets.control_path', 'sample.ctrl_img'],
|
additionalSections: ['datasets.control_path', 'sample.ctrl_img'],
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: 'flux2',
|
||||||
|
label: 'FLUX.2',
|
||||||
|
group: 'image',
|
||||||
|
defaults: {
|
||||||
|
// default updates when [selected, unselected] in the UI
|
||||||
|
'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.2-dev', defaultNameOrPath],
|
||||||
|
'config.process[0].model.quantize': [true, false],
|
||||||
|
'config.process[0].model.quantize_te': [true, false],
|
||||||
|
'config.process[0].model.low_vram': [true, false],
|
||||||
|
'config.process[0].train.unload_text_encoder': [false, false],
|
||||||
|
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||||
|
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||||
|
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
||||||
|
'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
|
||||||
|
'config.process[0].model.model_kwargs': [
|
||||||
|
{
|
||||||
|
match_target_res: false,
|
||||||
|
},
|
||||||
|
{},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
disableSections: ['network.conv'],
|
||||||
|
additionalSections: [
|
||||||
|
'datasets.multi_control_paths',
|
||||||
|
'sample.multi_ctrl_imgs',
|
||||||
|
'model.low_vram',
|
||||||
|
'model.layer_offloading',
|
||||||
|
'model.qie.match_target_res',
|
||||||
|
],
|
||||||
|
},
|
||||||
].sort((a, b) => {
|
].sort((a, b) => {
|
||||||
// Sort by label, case-insensitive
|
// Sort by label, case-insensitive
|
||||||
return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' });
|
return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' });
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
VERSION = "0.7.4"
|
VERSION = "0.7.5"
|
||||||
Reference in New Issue
Block a user