Merge pull request #278 from ostris/hidream

Add Hidream support
This commit is contained in:
Jaret Burkett
2025-04-16 13:48:27 -06:00
committed by GitHub
23 changed files with 3629 additions and 19 deletions

View File

@@ -0,0 +1,112 @@
# HiDream training is still highly experimental. The settings here will take ~35.2GB of vram to train.
# It is not possible to train on a single 24GB card yet, but I am working on it. If you have more VRAM
# I highly recommend first disabling quantization on the model itself if you can. You can leave the TEs quantized.
# HiDream has a mixture of experts that may take special training considerations that I do not
# have implemented properly. The current implementation seems to work well for LoRA training, but
# may not be effective for longer training runs. The implementation could change in future updates
# so your results may vary when this happens.
---
job: extension
config:
# this name will be the folder and filename name
name: "my_first_hidream_lora_v1"
process:
- type: 'sd_trainer'
# root folder to save training sessions/samples/weights
training_folder: "output"
# uncomment to see performance stats in the terminal every N steps
# performance_log_every: 1000
device: cuda:0
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
# trigger_word: "p3r5on"
network:
type: "lora"
linear: 32
linear_alpha: 32
network_kwargs:
# it is probably best to ignore the mixture of experts since only 2 are active each block. It works activating it, but I wouldnt.
# proper training of it is not fully implemented
ignore_if_contains:
- "ff_i.experts"
- "ff_i.gate"
save:
dtype: bfloat16 # precision to save
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
datasets:
# datasets are a folder of images. captions need to be txt files with the same name as the image
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
# images will automatically be resized and bucketed into the resolution specified
# on windows, escape back slashes with another backslash so
# "C:\\path\\to\\images\\folder"
- folder_path: "/path/to/images/folder"
caption_ext: "txt"
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
resolution: [ 512, 768, 1024 ] # hidream enjoys multiple resolutions
train:
batch_size: 1
steps: 3000 # total number of steps to train 500 - 4000 is a good range
gradient_accumulation_steps: 1
train_unet: true
train_text_encoder: false # wont work with hidream
gradient_checkpointing: true # need the on unless you have a ton of vram
noise_scheduler: "flowmatch" # for training only
timestep_type: shift # sigmoid, shift, linear
optimizer: "adamw8bit"
lr: 2e-4
# uncomment this to skip the pre training sample
# skip_first_sample: true
# uncomment to completely disable sampling
# disable_sampling: true
# uncomment to use new vell curved weighting. Experimental but may produce better results
# linear_timesteps: true
# ema will smooth out learning, but could slow it down. Defaults off
ema_config:
use_ema: false
ema_decay: 0.99
# will probably need this if gpu supports it for hidream, other dtypes may not work correctly
dtype: bf16
model:
# the transformer will get grabbed from this hf repo
# warning ONLY train on Full. The dev and fast models are distilled and will break
name_or_path: "HiDream-ai/HiDream-I1-Full"
# the extras will be grabbed from this hf repo. (text encoder, vae)
extras_name_or_path: "HiDream-ai/HiDream-I1-Full"
arch: "hidream"
# both need to be quantized to train on 48GB currently
quantize: true
quantize_te: true
model_kwargs:
# llama is a gated model, It defaults to unsloth version, but you can set the llama path here
llama_model_path: "unsloth/Meta-Llama-3.1-8B-Instruct"
sample:
sampler: "flowmatch" # must match train.noise_scheduler
sample_every: 250 # sample every this many steps
width: 1024
height: 1024
prompts:
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
- "woman with red hair, playing chess at the park, bomb going off in the background"
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
- "a bear building a log cabin in the snow covered mountains"
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
- "hipster man with a beard, building a chair, in a wood shop"
- "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
- "a man holding a sign that says, 'this is a sign'"
- "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
neg: ""
seed: 42
walk_seed: true
guidance_scale: 4
sample_steps: 25
# you can add any additional meta info here. [name] is replaced with config name at top
meta:
name: "[name]"
version: '1.0'

View File

@@ -1,4 +1,4 @@
FROM nvidia/cuda:12.6.3-base-ubuntu22.04
FROM nvidia/cuda:12.6.3-devel-ubuntu22.04
LABEL authors="jaret"
@@ -58,7 +58,8 @@ RUN echo "Cache bust: ${CACHEBUST}" && \
WORKDIR /app/ai-toolkit
# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt
RUN pip install --no-cache-dir -r requirements.txt && \
pip install flash-attn --no-build-isolation --no-cache-dir
# Build UI
WORKDIR /app/ai-toolkit/ui

View File

@@ -1,6 +1,7 @@
from .chroma import ChromaModel
from .hidream import HidreamModel
AI_TOOLKIT_MODELS = [
# put a list of models here
ChromaModel
ChromaModel, HidreamModel
]

View File

@@ -0,0 +1 @@
from .hidream_model import HidreamModel

View File

@@ -0,0 +1,445 @@
import os
from typing import TYPE_CHECKING, List, Optional
import einops
import torch
import torchvision
import yaml
from toolkit import train_tools
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from PIL import Image
from toolkit.models.base_model import BaseModel
from diffusers import AutoencoderKL, TorchAoConfig
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.accelerator import get_accelerator, unwrap_model
from optimum.quanto import freeze, QTensor
from toolkit.util.mask import generate_random_mask, random_dialate_mask
from toolkit.util.quantize import quantize, get_qtype
from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer, TorchAoConfig as TorchAoConfigTransformers
from .src.pipelines.hidream_image.pipeline_hidream_image import HiDreamImagePipeline
from .src.models.transformers.transformer_hidream_image import HiDreamImageTransformer2DModel
from .src.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
from einops import rearrange, repeat
import random
import torch.nn.functional as F
from tqdm import tqdm
from transformers import (
CLIPTextModelWithProjection,
CLIPTokenizer,
T5EncoderModel,
T5Tokenizer,
LlamaForCausalLM,
PreTrainedTokenizerFast
)
if TYPE_CHECKING:
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
scheduler_config = {
"num_train_timesteps": 1000,
"shift": 3.0
}
# LLAMA_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
LLAMA_MODEL_PATH = "unsloth/Meta-Llama-3.1-8B-Instruct"
BASE_MODEL_PATH = "HiDream-ai/HiDream-I1-Full"
class HidreamModel(BaseModel):
arch = "hidream"
def __init__(
self,
device,
model_config: ModelConfig,
dtype='bf16',
custom_pipeline=None,
noise_scheduler=None,
**kwargs
):
super().__init__(
device,
model_config,
dtype,
custom_pipeline,
noise_scheduler,
**kwargs
)
self.is_flow_matching = True
self.is_transformer = True
self.target_lora_modules = ['HiDreamImageTransformer2DModel']
# static method to get the noise scheduler
@staticmethod
def get_train_scheduler():
return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
def get_bucket_divisibility(self):
return 16
def load_model(self):
dtype = self.torch_dtype
# HiDream-ai/HiDream-I1-Full
self.print_and_status_update("Loading HiDream model")
# will be updated if we detect a existing checkpoint in training folder
model_path = self.model_config.name_or_path
extras_path = self.model_config.extras_name_or_path
llama_model_path = self.model_config.model_kwargs.get('llama_model_path', LLAMA_MODEL_PATH)
scheduler = HidreamModel.get_train_scheduler()
self.print_and_status_update("Loading llama 8b model")
tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(
llama_model_path,
use_fast=False
)
text_encoder_4 = LlamaForCausalLM.from_pretrained(
llama_model_path,
output_hidden_states=True,
output_attentions=True,
torch_dtype=torch.bfloat16,
)
text_encoder_4.to(self.device_torch, dtype=dtype)
if self.model_config.quantize_te:
self.print_and_status_update("Quantizing llama 8b model")
quantization_type = get_qtype(self.model_config.qtype_te)
quantize(text_encoder_4, weights=quantization_type)
freeze(text_encoder_4)
if self.low_vram:
# unload it for now
text_encoder_4.to('cpu')
flush()
self.print_and_status_update("Loading transformer")
transformer = HiDreamImageTransformer2DModel.from_pretrained(
model_path,
subfolder="transformer",
torch_dtype=torch.bfloat16
)
if not self.low_vram:
transformer.to(self.device_torch, dtype=dtype)
if self.model_config.quantize:
self.print_and_status_update("Quantizing transformer")
quantization_type = get_qtype(self.model_config.qtype)
if self.low_vram:
# move and quantize only certain pieces at a time.
all_blocks = list(transformer.double_stream_blocks) + list(transformer.single_stream_blocks)
self.print_and_status_update(" - quantizing transformer blocks")
for block in tqdm(all_blocks):
block.to(self.device_torch, dtype=dtype)
quantize(block, weights=quantization_type)
freeze(block)
block.to('cpu')
# flush()
self.print_and_status_update(" - quantizing extras")
transformer.to(self.device_torch, dtype=dtype)
quantize(transformer, weights=quantization_type)
freeze(transformer)
else:
quantize(transformer, weights=quantization_type)
freeze(transformer)
if self.low_vram:
# unload it for now
transformer.to('cpu')
flush()
self.print_and_status_update("Loading vae")
vae = AutoencoderKL.from_pretrained(
extras_path,
subfolder="vae",
torch_dtype=torch.bfloat16
).to(self.device_torch, dtype=dtype)
self.print_and_status_update("Loading clip encoders")
text_encoder = CLIPTextModelWithProjection.from_pretrained(
extras_path,
subfolder="text_encoder",
torch_dtype=torch.bfloat16
).to(self.device_torch, dtype=dtype)
tokenizer = CLIPTokenizer.from_pretrained(
extras_path,
subfolder="tokenizer"
)
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
extras_path,
subfolder="text_encoder_2",
torch_dtype=torch.bfloat16
).to(self.device_torch, dtype=dtype)
tokenizer_2 = CLIPTokenizer.from_pretrained(
extras_path,
subfolder="tokenizer_2"
)
flush()
self.print_and_status_update("Loading T5 encoders")
text_encoder_3 = T5EncoderModel.from_pretrained(
extras_path,
subfolder="text_encoder_3",
torch_dtype=torch.bfloat16
).to(self.device_torch, dtype=dtype)
if self.model_config.quantize_te:
self.print_and_status_update("Quantizing T5")
quantization_type = get_qtype(self.model_config.qtype_te)
quantize(text_encoder_3, weights=quantization_type)
freeze(text_encoder_3)
flush()
tokenizer_3 = T5Tokenizer.from_pretrained(
extras_path,
subfolder="tokenizer_3"
)
flush()
if self.low_vram:
self.print_and_status_update("Moving ecerything to device")
# move it all back
transformer.to(self.device_torch, dtype=dtype)
vae.to(self.device_torch, dtype=dtype)
text_encoder.to(self.device_torch, dtype=dtype)
text_encoder_2.to(self.device_torch, dtype=dtype)
text_encoder_4.to(self.device_torch, dtype=dtype)
text_encoder_3.to(self.device_torch, dtype=dtype)
# set to eval mode
# transformer.eval()
vae.eval()
text_encoder.eval()
text_encoder_2.eval()
text_encoder_4.eval()
text_encoder_3.eval()
pipe = HiDreamImagePipeline(
scheduler=scheduler,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
text_encoder_3=text_encoder_3,
tokenizer_3=tokenizer_3,
text_encoder_4=text_encoder_4,
tokenizer_4=tokenizer_4,
transformer=transformer,
)
flush()
text_encoder_list = [text_encoder, text_encoder_2, text_encoder_3, text_encoder_4]
tokenizer_list = [tokenizer, tokenizer_2, tokenizer_3, tokenizer_4]
for te in text_encoder_list:
# set the dtype
te.to(self.device_torch, dtype=dtype)
# freeze the model
freeze(te)
# set to eval mode
te.eval()
# set the requires grad to false
te.requires_grad_(False)
flush()
# save it to the model class
self.vae = vae
self.text_encoder = text_encoder_list # list of text encoders
self.tokenizer = tokenizer_list # list of tokenizers
self.model = pipe.transformer
self.pipeline = pipe
self.print_and_status_update("Model Loaded")
def get_generation_pipeline(self):
scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=1000,
shift=3.0,
use_dynamic_shifting=False
)
pipeline: HiDreamImagePipeline = HiDreamImagePipeline(
scheduler=scheduler,
vae=self.vae,
text_encoder=self.text_encoder[0],
tokenizer=self.tokenizer[0],
text_encoder_2=self.text_encoder[1],
tokenizer_2=self.tokenizer[1],
text_encoder_3=self.text_encoder[2],
tokenizer_3=self.tokenizer[2],
text_encoder_4=self.text_encoder[3],
tokenizer_4=self.tokenizer[3],
transformer=unwrap_model(self.model),
aggressive_unloading=self.low_vram
)
pipeline = pipeline.to(self.device_torch)
return pipeline
def generate_single_image(
self,
pipeline: HiDreamImagePipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
generator: torch.Generator,
extra: dict,
):
img = pipeline(
prompt_embeds=conditional_embeds.text_embeds,
pooled_prompt_embeds=conditional_embeds.pooled_embeds,
negative_prompt_embeds=unconditional_embeds.text_embeds,
negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
latents=gen_config.latents,
generator=generator,
**extra
).images[0]
return img
def get_noise_prediction(
self,
latent_model_input: torch.Tensor,
timestep: torch.Tensor, # 0 to 1000 scale
text_embeddings: PromptEmbeds,
**kwargs
):
with torch.no_grad():
if latent_model_input.shape[-2] != latent_model_input.shape[-1]:
B, C, H, W = latent_model_input.shape
pH, pW = H // self.model.config.patch_size, W // self.model.config.patch_size
img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
img_ids = torch.zeros(pH, pW, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :]
img_ids = img_ids.reshape(pH * pW, -1)
img_ids_pad = torch.zeros(self.transformer.max_seq, 3)
img_ids_pad[:pH*pW, :] = img_ids
img_sizes = img_sizes.unsqueeze(0).to(latent_model_input.device)
img_ids = img_ids_pad.unsqueeze(0).to(latent_model_input.device)
else:
img_sizes = img_ids = None
dtype = self.model.dtype
device = self.device_torch
# Pack the latent
if latent_model_input.shape[-2] != latent_model_input.shape[-1]:
B, C, H, W = latent_model_input.shape
patch_size = self.transformer.config.patch_size
pH, pW = H // patch_size, W // patch_size
out = torch.zeros(
(B, C, self.transformer.max_seq, patch_size * patch_size),
dtype=latent_model_input.dtype,
device=latent_model_input.device
)
latent_model_input = einops.rearrange(latent_model_input, 'B C (H p1) (W p2) -> B C (H W) (p1 p2)', p1=patch_size, p2=patch_size)
out[:, :, 0:pH*pW] = latent_model_input
latent_model_input = out
text_embeds = text_embeddings.text_embeds
# run the to for the list
text_embeds = [te.to(device, dtype=dtype) for te in text_embeds]
noise_pred = self.transformer(
hidden_states = latent_model_input,
timesteps = timestep,
encoder_hidden_states = text_embeds,
pooled_embeds = text_embeddings.pooled_embeds.to(device, dtype=dtype),
img_sizes = img_sizes,
img_ids = img_ids,
return_dict = False,
)[0]
noise_pred = -noise_pred
return noise_pred
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
self.text_encoder_to(self.device_torch, dtype=self.torch_dtype)
max_sequence_length = 128
prompt_embeds, pooled_prompt_embeds = self.pipeline._encode_prompt(
prompt = prompt,
prompt_2 = prompt,
prompt_3 = prompt,
prompt_4 = prompt,
device = self.device_torch,
dtype = self.torch_dtype,
num_images_per_prompt = 1,
max_sequence_length = max_sequence_length,
)
pe = PromptEmbeds(
[prompt_embeds, pooled_prompt_embeds]
)
return pe
def get_model_has_grad(self):
# return from a weight if it has grad
return self.model.double_stream_blocks[0].block.attn1.to_q.weight.requires_grad
def get_te_has_grad(self):
# assume no one wants to finetune 4 text encoders.
return False
def save_model(self, output_path, meta, save_dtype):
# only save the unet
transformer: HiDreamImageTransformer2DModel = unwrap_model(self.model)
transformer.save_pretrained(
save_directory=os.path.join(output_path, 'transformer'),
safe_serialization=True,
)
meta_path = os.path.join(output_path, 'aitk_meta.yaml')
with open(meta_path, 'w') as f:
yaml.dump(meta, f)
def get_loss_target(self, *args, **kwargs):
noise = kwargs.get('noise')
batch = kwargs.get('batch')
return (noise - batch.latents).detach()
def get_transformer_block_names(self) -> Optional[List[str]]:
return ['double_stream_blocks', 'single_stream_blocks']
def convert_lora_weights_before_save(self, state_dict):
# currently starte with transformer. but needs to start with diffusion_model. for comfyui
new_sd = {}
for key, value in state_dict.items():
new_key = key.replace("transformer.", "diffusion_model.")
new_sd[new_key] = value
return new_sd
def convert_lora_weights_before_load(self, state_dict):
# saved as diffusion_model. but needs to be transformer. for ai-toolkit
new_sd = {}
for key, value in state_dict.items():
new_key = key.replace("diffusion_model.", "transformer.")
new_sd[new_key] = value
return new_sd

View File

@@ -0,0 +1,2 @@
from .models.transformers.transformer_hidream_image import HiDreamImageTransformer2DModel
from .pipelines.hidream_image.pipeline_hidream_image import HiDreamImagePipeline

View File

@@ -0,0 +1,106 @@
import torch
from torch import nn
from typing import Optional
from diffusers.models.attention_processor import Attention
from diffusers.utils.torch_utils import maybe_allow_in_graph
@maybe_allow_in_graph
class HiDreamAttention(Attention):
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
upcast_attention: bool = False,
upcast_softmax: bool = False,
scale_qk: bool = True,
eps: float = 1e-5,
processor = None,
out_dim: int = None,
single: bool = False
):
super(Attention, self).__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.out_dim = out_dim if out_dim is not None else query_dim
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
self.heads = out_dim // dim_head if out_dim is not None else heads
self.sliceable_head_dim = heads
self.single = single
linear_cls = nn.Linear
self.linear_cls = linear_cls
self.to_q = linear_cls(query_dim, self.inner_dim)
self.to_k = linear_cls(self.inner_dim, self.inner_dim)
self.to_v = linear_cls(self.inner_dim, self.inner_dim)
self.to_out = linear_cls(self.inner_dim, self.out_dim)
self.q_rms_norm = nn.RMSNorm(self.inner_dim, eps)
self.k_rms_norm = nn.RMSNorm(self.inner_dim, eps)
if not single:
self.to_q_t = linear_cls(query_dim, self.inner_dim)
self.to_k_t = linear_cls(self.inner_dim, self.inner_dim)
self.to_v_t = linear_cls(self.inner_dim, self.inner_dim)
self.to_out_t = linear_cls(self.inner_dim, self.out_dim)
self.q_rms_norm_t = nn.RMSNorm(self.inner_dim, eps)
self.k_rms_norm_t = nn.RMSNorm(self.inner_dim, eps)
self.set_processor(processor)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(
self,
norm_image_tokens: torch.FloatTensor,
image_tokens_masks: torch.FloatTensor = None,
norm_text_tokens: torch.FloatTensor = None,
rope: torch.FloatTensor = None,
) -> torch.Tensor:
return self.processor(
self,
image_tokens = norm_image_tokens,
image_tokens_masks = image_tokens_masks,
text_tokens = norm_text_tokens,
rope = rope,
)
class FeedForwardSwiGLU(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * (
(hidden_dim + multiple_of - 1) // multiple_of
)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))

View File

@@ -0,0 +1,121 @@
from typing import Optional
import torch
from .attention import HiDreamAttention
# Try to import Flash Attention first
flash_attn_available = False
try:
from flash_attn_interface import flash_attn_func
USE_FLASH_ATTN3 = True
flash_attn_available = True
except ImportError:
try:
from flash_attn import flash_attn_func
USE_FLASH_ATTN3 = False
flash_attn_available = True
except ImportError:
USE_FLASH_ATTN3 = False
flash_attn_available = False
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
if flash_attn_available:
if USE_FLASH_ATTN3:
hidden_states = flash_attn_func(query, key, value, causal=False, deterministic=False)[0]
else:
hidden_states = flash_attn_func(query, key, value, dropout_p=0., causal=False)
else:
# Use torch's scaled dot-product attention as fallback
# Reshape for torch.nn.functional.scaled_dot_product_attention which expects [batch, heads, seq_len, head_dim]
query = query.transpose(1, 2) # [batch, heads, seq_len, head_dim]
key = key.transpose(1, 2)
value = value.transpose(1, 2)
hidden_states = torch.nn.functional.scaled_dot_product_attention(
query, key, value,
attn_mask=None,
dropout_p=0.0,
is_causal=False
)
# Restore original shape
hidden_states = hidden_states.transpose(1, 2) # [batch, seq_len, heads, head_dim]
hidden_states = hidden_states.flatten(-2)
hidden_states = hidden_states.to(query.dtype)
return hidden_states
class HiDreamAttnProcessor_flashattn:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __call__(
self,
attn: HiDreamAttention,
image_tokens: torch.FloatTensor,
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
*args,
**kwargs,
) -> torch.FloatTensor:
dtype = image_tokens.dtype
batch_size = image_tokens.shape[0]
query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype)
key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype)
value_i = attn.to_v(image_tokens)
inner_dim = key_i.shape[-1]
head_dim = inner_dim // attn.heads
query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
if image_tokens_masks is not None:
key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1)
if not attn.single:
query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype)
key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype)
value_t = attn.to_v_t(text_tokens)
query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
value_t = value_t.view(batch_size, -1, attn.heads, head_dim)
num_image_tokens = query_i.shape[1]
num_text_tokens = query_t.shape[1]
query = torch.cat([query_i, query_t], dim=1)
key = torch.cat([key_i, key_t], dim=1)
value = torch.cat([value_i, value_t], dim=1)
else:
query = query_i
key = key_i
value = value_i
if query.shape[-1] == rope.shape[-3] * 2:
query, key = apply_rope(query, key, rope)
else:
query_1, query_2 = query.chunk(2, dim=-1)
key_1, key_2 = key.chunk(2, dim=-1)
query_1, key_1 = apply_rope(query_1, key_1, rope)
query = torch.cat([query_1, query_2], dim=-1)
key = torch.cat([key_1, key_2], dim=-1)
hidden_states = attention(query, key, value)
if not attn.single:
hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
hidden_states_i = attn.to_out(hidden_states_i)
hidden_states_t = attn.to_out_t(hidden_states_t)
return hidden_states_i, hidden_states_t
else:
hidden_states = attn.to_out(hidden_states)
return hidden_states

View File

@@ -0,0 +1,114 @@
import torch
from torch import nn
from typing import List
from diffusers.models.embeddings import Timesteps, TimestepEmbedding
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even."
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
batch_size, seq_length = pos.shape
out = torch.einsum("...n,d->...nd", pos, omega)
cos_out = torch.cos(out)
sin_out = torch.sin(out)
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
return out.float()
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class EmbedND(nn.Module):
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(2)
class PatchEmbed(nn.Module):
def __init__(
self,
patch_size=2,
in_channels=4,
out_channels=1024,
):
super().__init__()
self.patch_size = patch_size
self.out_channels = out_channels
self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, latent):
latent = self.proj(latent)
return latent
class PooledEmbed(nn.Module):
def __init__(self, text_emb_dim, hidden_size):
super().__init__()
self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, pooled_embed):
return self.pooled_embedder(pooled_embed)
class TimestepEmbed(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, timesteps, wdtype):
t_emb = self.time_proj(timesteps).to(dtype=wdtype)
t_emb = self.timestep_embedder(t_emb)
return t_emb
class OutEmbed(nn.Module):
def __init__(self, hidden_size, patch_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.zeros_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x, adaln_input):
shift, scale = self.adaLN_modulation(adaln_input).chunk(2, dim=1)
x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
x = self.linear(x)
return x

View File

@@ -0,0 +1,157 @@
import math
import torch
from torch import nn
import torch.nn.functional as F
from .attention import FeedForwardSwiGLU
from torch.distributed.nn.functional import all_gather
_LOAD_BALANCING_LOSS = []
def save_load_balancing_loss(loss):
global _LOAD_BALANCING_LOSS
_LOAD_BALANCING_LOSS.append(loss)
def clear_load_balancing_loss():
global _LOAD_BALANCING_LOSS
_LOAD_BALANCING_LOSS.clear()
def get_load_balancing_loss():
global _LOAD_BALANCING_LOSS
return _LOAD_BALANCING_LOSS
def batched_load_balancing_loss():
aux_losses_arr = get_load_balancing_loss()
alpha = aux_losses_arr[0][-1]
Pi = torch.stack([ent[1] for ent in aux_losses_arr], dim=0)
fi = torch.stack([ent[2] for ent in aux_losses_arr], dim=0)
fi_list = all_gather(fi)
fi = torch.stack(fi_list, 0).mean(0)
aux_loss = (Pi * fi).sum(-1).mean() * alpha
return aux_loss
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MoEGate(nn.Module):
def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01):
super().__init__()
self.top_k = num_activated_experts
self.n_routed_experts = num_routed_experts
self.scoring_func = 'softmax'
self.alpha = aux_loss_alpha
self.seq_aux = False
# topk selection algorithm
self.norm_topk_prob = False
self.gating_dim = embed_dim
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
self.reset_parameters()
def reset_parameters(self) -> None:
import torch.nn.init as init
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
# print(bsz, seq_len, h)
### compute gating score
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, self.weight, None)
if self.scoring_func == 'softmax':
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
### select top-k experts
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
### norm gate to sum 1
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
# this was in original and memory leaks, not needed
# ### expert-level computation auxiliary loss
# if self.training and self.alpha > 0.0:
# scores_for_aux = scores
# aux_topk = self.top_k
# # always compute aux loss based on the naive greedy topk method
# topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
# if self.seq_aux:
# scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
# ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
# ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(seq_len * aux_topk / self.n_routed_experts)
# aux_loss = (ce * scores_for_seq_aux.mean(dim = 1)).sum(dim = 1).mean() * self.alpha
# else:
# mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
# ce = mask_ce.float().mean(0)
# Pi = scores_for_aux.mean(0)
# fi = ce * self.n_routed_experts
# aux_loss = (Pi * fi).sum() * self.alpha
# save_load_balancing_loss((aux_loss, Pi, fi, self.alpha))
# else:
aux_loss = None
return topk_idx, topk_weight, aux_loss
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MOEFeedForwardSwiGLU(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
num_routed_experts: int,
num_activated_experts: int,
):
super().__init__()
self.shared_experts = FeedForwardSwiGLU(dim, hidden_dim // 2)
self.experts = nn.ModuleList([FeedForwardSwiGLU(dim, hidden_dim) for i in range(num_routed_experts)])
self.gate = MoEGate(
embed_dim = dim,
num_routed_experts = num_routed_experts,
num_activated_experts = num_activated_experts
)
self.num_activated_experts = num_activated_experts
def forward(self, x):
wtype = x.dtype
identity = x
orig_shape = x.shape
topk_idx, topk_weight, aux_loss = self.gate(x)
x = x.view(-1, x.shape[-1])
flat_topk_idx = topk_idx.view(-1)
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
# this was in original and memory leaks, not needed
# if self.training:
# x = x.repeat_interleave(self.num_activated_experts, dim=0)
# y = torch.empty_like(x, dtype=wtype)
# for i, expert in enumerate(self.experts):
# y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
# y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
# y = y.view(*orig_shape).to(dtype=wtype)
# #y = AddAuxiliaryLoss.apply(y, aux_loss)
# else:
# y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
y = y + self.shared_experts(identity)
return y
# @torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
token_idxs = idxs // self.num_activated_experts
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i-1]
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
# for fp16 and other dtype
expert_cache = expert_cache.to(expert_out.dtype)
expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum')
return expert_cache

View File

@@ -0,0 +1,506 @@
from typing import Any, Callable, Dict, Optional, Tuple, List
import torch
import torch.nn as nn
import einops
from einops import repeat
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from ..embeddings import PatchEmbed, PooledEmbed, TimestepEmbed, EmbedND, OutEmbed
from ..attention import HiDreamAttention, FeedForwardSwiGLU
from ..attention_processor import HiDreamAttnProcessor_flashattn
from ..moe import MOEFeedForwardSwiGLU
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class TextProjection(nn.Module):
def __init__(self, in_features, hidden_size):
super().__init__()
self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False)
def forward(self, caption):
hidden_states = self.linear(caption)
return hidden_states
class BlockType:
TransformerBlock = 1
SingleTransformerBlock = 2
@maybe_allow_in_graph
class HiDreamImageSingleTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
num_routed_experts: int = 4,
num_activated_experts: int = 2
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, 6 * dim, bias=True)
)
nn.init.zeros_(self.adaLN_modulation[1].weight)
nn.init.zeros_(self.adaLN_modulation[1].bias)
# 1. Attention
self.norm1_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
self.attn1 = HiDreamAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
processor = HiDreamAttnProcessor_flashattn(),
single = True
)
# 3. Feed-forward
self.norm3_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
if num_routed_experts > 0:
self.ff_i = MOEFeedForwardSwiGLU(
dim = dim,
hidden_dim = 4 * dim,
num_routed_experts = num_routed_experts,
num_activated_experts = num_activated_experts,
)
else:
self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim)
def forward(
self,
image_tokens: torch.FloatTensor,
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
) -> torch.FloatTensor:
wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
self.adaLN_modulation(adaln_input)[:,None].chunk(6, dim=-1)
# 1. MM-Attention
norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
attn_output_i = self.attn1(
norm_image_tokens,
image_tokens_masks,
rope = rope,
)
image_tokens = gate_msa_i * attn_output_i + image_tokens
# 2. Feed-forward
norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype))
image_tokens = ff_output_i + image_tokens
return image_tokens
@maybe_allow_in_graph
class HiDreamImageTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
num_routed_experts: int = 4,
num_activated_experts: int = 2
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, 12 * dim, bias=True)
)
nn.init.zeros_(self.adaLN_modulation[1].weight)
nn.init.zeros_(self.adaLN_modulation[1].bias)
# 1. Attention
self.norm1_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
self.norm1_t = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
self.attn1 = HiDreamAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
processor = HiDreamAttnProcessor_flashattn(),
single = False
)
# 3. Feed-forward
self.norm3_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
if num_routed_experts > 0:
self.ff_i = MOEFeedForwardSwiGLU(
dim = dim,
hidden_dim = 4 * dim,
num_routed_experts = num_routed_experts,
num_activated_experts = num_activated_experts,
)
else:
self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim)
self.norm3_t = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
self.ff_t = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim)
def forward(
self,
image_tokens: torch.FloatTensor,
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
) -> torch.FloatTensor:
wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \
self.adaLN_modulation(adaln_input)[:,None].chunk(12, dim=-1)
# 1. MM-Attention
norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype)
norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t
attn_output_i, attn_output_t = self.attn1(
norm_image_tokens,
image_tokens_masks,
norm_text_tokens,
rope = rope,
)
image_tokens = gate_msa_i * attn_output_i + image_tokens
text_tokens = gate_msa_t * attn_output_t + text_tokens
# 2. Feed-forward
norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype)
norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t
ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens)
ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens)
image_tokens = ff_output_i + image_tokens
text_tokens = ff_output_t + text_tokens
return image_tokens, text_tokens
@maybe_allow_in_graph
class HiDreamImageBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
block_type: BlockType = BlockType.TransformerBlock,
):
super().__init__()
block_classes = {
BlockType.TransformerBlock: HiDreamImageTransformerBlock,
BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock,
}
self.block = block_classes[block_type](
dim,
num_attention_heads,
attention_head_dim,
num_routed_experts,
num_activated_experts
)
def forward(
self,
image_tokens: torch.FloatTensor,
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: torch.FloatTensor = None,
rope: torch.FloatTensor = None,
) -> torch.FloatTensor:
return self.block(
image_tokens,
image_tokens_masks,
text_tokens,
adaln_input,
rope,
)
class HiDreamImageTransformer2DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
):
_supports_gradient_checkpointing = True
_no_split_modules = ["HiDreamImageBlock"]
@register_to_config
def __init__(
self,
patch_size: Optional[int] = None,
in_channels: int = 64,
out_channels: Optional[int] = None,
num_layers: int = 16,
num_single_layers: int = 32,
attention_head_dim: int = 128,
num_attention_heads: int = 20,
caption_channels: List[int] = None,
text_emb_dim: int = 2048,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
axes_dims_rope: Tuple[int, int] = (32, 32),
max_resolution: Tuple[int, int] = (128, 128),
llama_layers: List[int] = None,
):
super().__init__()
self.out_channels = out_channels or in_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.llama_layers = llama_layers
self.t_embedder = TimestepEmbed(self.inner_dim)
self.p_embedder = PooledEmbed(text_emb_dim, self.inner_dim)
self.x_embedder = PatchEmbed(
patch_size = patch_size,
in_channels = in_channels,
out_channels = self.inner_dim,
)
self.pe_embedder = EmbedND(theta=10000, axes_dim=axes_dims_rope)
self.double_stream_blocks = nn.ModuleList(
[
HiDreamImageBlock(
dim = self.inner_dim,
num_attention_heads = self.config.num_attention_heads,
attention_head_dim = self.config.attention_head_dim,
num_routed_experts = num_routed_experts,
num_activated_experts = num_activated_experts,
block_type = BlockType.TransformerBlock
)
for i in range(self.config.num_layers)
]
)
self.single_stream_blocks = nn.ModuleList(
[
HiDreamImageBlock(
dim = self.inner_dim,
num_attention_heads = self.config.num_attention_heads,
attention_head_dim = self.config.attention_head_dim,
num_routed_experts = num_routed_experts,
num_activated_experts = num_activated_experts,
block_type = BlockType.SingleTransformerBlock
)
for i in range(self.config.num_single_layers)
]
)
self.final_layer = OutEmbed(self.inner_dim, patch_size, self.out_channels)
caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ]
caption_projection = []
for caption_channel in caption_channels:
caption_projection.append(TextProjection(in_features = caption_channel, hidden_size = self.inner_dim))
self.caption_projection = nn.ModuleList(caption_projection)
self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)
self.gradient_checkpointing = False
def expand_timesteps(self, timesteps, batch_size, device):
if not torch.is_tensor(timesteps):
is_mps = device.type == "mps"
if isinstance(timesteps, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(batch_size)
return timesteps
# the implementation on hidream during train was wrong, just use the inference one.
def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
# Process all images in the batch according to their specific dimensions
x_arr = []
for i, img_size in enumerate(img_sizes):
pH, pW = img_size
x_arr.append(
einops.rearrange(
x[i, :pH*pW].reshape(1, pH, pW, -1),
'B H W (p1 p2 C) -> B C (H p1) (W p2)',
p1=self.config.patch_size, p2=self.config.patch_size
)
)
x = torch.cat(x_arr, dim=0)
return x
def patchify(self, x, max_seq, img_sizes=None):
pz2 = self.config.patch_size * self.config.patch_size
if isinstance(x, torch.Tensor):
B, C = x.shape[0], x.shape[1]
device = x.device
dtype = x.dtype
else:
B, C = len(x), x[0].shape[0]
device = x[0].device
dtype = x[0].dtype
x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)
if img_sizes is not None:
for i, img_size in enumerate(img_sizes):
x_masks[i, 0:img_size[0] * img_size[1]] = 1
x = einops.rearrange(x, 'B C S p -> B S (p C)', p=pz2)
elif isinstance(x, torch.Tensor):
pH, pW = x.shape[-2] // self.config.patch_size, x.shape[-1] // self.config.patch_size
x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.config.patch_size, p2=self.config.patch_size)
img_sizes = [[pH, pW]] * B
x_masks = None
else:
raise NotImplementedError
return x, x_masks, img_sizes
def forward(
self,
hidden_states: torch.Tensor,
timesteps: torch.LongTensor = None,
encoder_hidden_states: torch.Tensor = None,
pooled_embeds: torch.Tensor = None,
img_sizes: Optional[List[Tuple[int, int]]] = None,
img_ids: Optional[torch.Tensor] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
):
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
# spatial forward
batch_size = hidden_states.shape[0]
hidden_states_type = hidden_states.dtype
# 0. time
timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
timesteps = self.t_embedder(timesteps, hidden_states_type)
p_embedder = self.p_embedder(pooled_embeds)
adaln_input = timesteps + p_embedder
hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
if image_tokens_masks is None:
pH, pW = img_sizes[0]
img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
hidden_states = self.x_embedder(hidden_states)
T5_encoder_hidden_states = encoder_hidden_states[0]
encoder_hidden_states = encoder_hidden_states[-1]
encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]
if self.caption_projection is not None:
new_encoder_hidden_states = []
for i, enc_hidden_state in enumerate(encoder_hidden_states):
enc_hidden_state = self.caption_projection[i](enc_hidden_state)
enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
new_encoder_hidden_states.append(enc_hidden_state)
encoder_hidden_states = new_encoder_hidden_states
T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
encoder_hidden_states.append(T5_encoder_hidden_states)
txt_ids = torch.zeros(
batch_size,
encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1],
3,
device=img_ids.device, dtype=img_ids.dtype
)
ids = torch.cat((img_ids, txt_ids), dim=1)
rope = self.pe_embedder(ids)
# 2. Blocks
block_id = 0
initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
for bid, block in enumerate(self.double_stream_blocks):
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id].detach()
cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
image_tokens_masks,
cur_encoder_hidden_states,
adaln_input.clone(),
rope.clone(),
)
else:
hidden_states, initial_encoder_hidden_states = block(
image_tokens = hidden_states,
image_tokens_masks = image_tokens_masks,
text_tokens = cur_encoder_hidden_states,
adaln_input = adaln_input,
rope = rope,
)
initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
block_id += 1
image_tokens_seq_len = hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
hidden_states_seq_len = hidden_states.shape[1]
if image_tokens_masks is not None:
encoder_attention_mask_ones = torch.ones(
(batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
device=image_tokens_masks.device, dtype=image_tokens_masks.dtype
)
image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)
for bid, block in enumerate(self.single_stream_blocks):
cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id].detach()
hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
image_tokens_masks,
None,
adaln_input.clone(),
rope.clone(),
)
else:
hidden_states = block(
image_tokens = hidden_states,
image_tokens_masks = image_tokens_masks,
text_tokens = None,
adaln_input = adaln_input,
rope = rope,
)
hidden_states = hidden_states[:, :hidden_states_seq_len]
block_id += 1
hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
output = self.final_layer(hidden_states, adaln_input)
output = self.unpatchify(output, img_sizes, self.training)
if image_tokens_masks is not None:
image_tokens_masks = image_tokens_masks[:, :image_tokens_seq_len]
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output, image_tokens_masks)
return Transformer2DModelOutput(sample=output, mask=image_tokens_masks)

View File

@@ -0,0 +1,737 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import math
import einops
import torch
from transformers import (
CLIPTextModelWithProjection,
CLIPTokenizer,
T5EncoderModel,
T5Tokenizer,
LlamaForCausalLM,
PreTrainedTokenizerFast
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from .pipeline_output import HiDreamImagePipelineOutput
from ...models.transformers.transformer_hidream_image import HiDreamImageTransformer2DModel
from ...schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->image_encoder->transformer->vae"
_optional_components = ["image_encoder", "feature_extractor"]
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
text_encoder_2: CLIPTextModelWithProjection,
tokenizer_2: CLIPTokenizer,
text_encoder_3: T5EncoderModel,
tokenizer_3: T5Tokenizer,
text_encoder_4: LlamaForCausalLM,
tokenizer_4: PreTrainedTokenizerFast,
transformer: HiDreamImageTransformer2DModel,
aggressive_unloading: bool = False,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
text_encoder_3=text_encoder_3,
text_encoder_4=text_encoder_4,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
tokenizer_3=tokenizer_3,
tokenizer_4=tokenizer_4,
scheduler=scheduler,
transformer=transformer,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
# HiDreamImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.default_sample_size = 128
self.tokenizer_4.pad_token = self.tokenizer_4.eos_token
self.aggressive_unloading = aggressive_unloading
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 128,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder_3.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer_3(
prompt,
padding="max_length",
max_length=min(max_sequence_length, self.tokenizer_3.model_max_length),
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, min(max_sequence_length, self.tokenizer_3.model_max_length) - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {min(max_sequence_length, self.tokenizer_3.model_max_length)} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=attention_mask.to(device))[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
def _get_clip_prompt_embeds(
self,
tokenizer,
text_encoder,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
max_sequence_length: int = 128,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=min(max_sequence_length, 218),
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = tokenizer.batch_decode(untruncated_ids[:, 218 - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {218} tokens: {removed_text}"
)
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds
def _get_llama3_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 128,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder_4.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer_4(
prompt,
padding="max_length",
max_length=min(max_sequence_length, self.tokenizer_4.model_max_length),
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
untruncated_ids = self.tokenizer_4(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer_4.batch_decode(untruncated_ids[:, min(max_sequence_length, self.tokenizer_4.model_max_length) - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {min(max_sequence_length, self.tokenizer_4.model_max_length)} tokens: {removed_text}"
)
outputs = self.text_encoder_4(
text_input_ids.to(device),
attention_mask=attention_mask.to(device),
output_hidden_states=True,
output_attentions=True
)
prompt_embeds = outputs.hidden_states[1:]
prompt_embeds = torch.stack(prompt_embeds, dim=0)
_, _, seq_len, dim = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, 1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(-1, batch_size * num_images_per_prompt, seq_len, dim)
return prompt_embeds
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_3: Union[str, List[str]],
prompt_4: Union[str, List[str]],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt_3: Optional[Union[str, List[str]]] = None,
negative_prompt_4: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 128,
lora_scale: Optional[float] = None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds[0].shape[0]
prompt_embeds, pooled_prompt_embeds = self._encode_prompt(
prompt = prompt,
prompt_2 = prompt_2,
prompt_3 = prompt_3,
prompt_4 = prompt_4,
device = device,
dtype = dtype,
num_images_per_prompt = num_images_per_prompt,
prompt_embeds = prompt_embeds,
pooled_prompt_embeds = pooled_prompt_embeds,
max_sequence_length = max_sequence_length,
)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt_2 or negative_prompt
negative_prompt_3 = negative_prompt_3 or negative_prompt
negative_prompt_4 = negative_prompt_4 or negative_prompt
# normalize str to list
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
negative_prompt_2 = (
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
)
negative_prompt_3 = (
batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
)
negative_prompt_4 = (
batch_size * [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4
)
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds, negative_pooled_prompt_embeds = self._encode_prompt(
prompt = negative_prompt,
prompt_2 = negative_prompt_2,
prompt_3 = negative_prompt_3,
prompt_4 = negative_prompt_4,
device = device,
dtype = dtype,
num_images_per_prompt = num_images_per_prompt,
prompt_embeds = negative_prompt_embeds,
pooled_prompt_embeds = negative_pooled_prompt_embeds,
max_sequence_length = max_sequence_length,
)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
def _encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_3: Union[str, List[str]],
prompt_4: Union[str, List[str]],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 128,
):
device = device or self._execution_device
if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
prompt_3 = prompt_3 or prompt
prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
prompt_4 = prompt_4 or prompt
prompt_4 = [prompt_4] if isinstance(prompt_4, str) else prompt_4
pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
self.tokenizer,
self.text_encoder,
prompt = prompt,
num_images_per_prompt = num_images_per_prompt,
max_sequence_length = max_sequence_length,
device = device,
dtype = dtype,
)
pooled_prompt_embeds_2 = self._get_clip_prompt_embeds(
self.tokenizer_2,
self.text_encoder_2,
prompt = prompt_2,
num_images_per_prompt = num_images_per_prompt,
max_sequence_length = max_sequence_length,
device = device,
dtype = dtype,
)
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1)
t5_prompt_embeds = self._get_t5_prompt_embeds(
prompt = prompt_3,
num_images_per_prompt = num_images_per_prompt,
max_sequence_length = max_sequence_length,
device = device,
dtype = dtype
)
llama3_prompt_embeds = self._get_llama3_prompt_embeds(
prompt = prompt_4,
num_images_per_prompt = num_images_per_prompt,
max_sequence_length = max_sequence_length,
device = device,
dtype = dtype
)
prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds]
return prompt_embeds, pooled_prompt_embeds
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
prompt_3: Optional[Union[str, List[str]]] = None,
prompt_4: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt_3: Optional[Union[str, List[str]]] = None,
negative_prompt_4: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 128,
):
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
division = self.vae_scale_factor * 2
S_max = (self.default_sample_size * self.vae_scale_factor) ** 2
scale = S_max / (width * height)
scale = math.sqrt(scale)
width, height = int(width * scale // division * division), int(height * scale // division * division)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds[0].shape[0]
device = self._execution_device
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_3=prompt_3,
prompt_4=prompt_4,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
negative_prompt_3=negative_prompt_3,
negative_prompt_4=negative_prompt_4,
do_classifier_free_guidance=self.do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
if self.do_classifier_free_guidance:
prompt_embeds_arr = []
for n, p in zip(negative_prompt_embeds, prompt_embeds):
if len(n.shape) == 3:
prompt_embeds_arr.append(torch.cat([n, p], dim=0))
else:
prompt_embeds_arr.append(torch.cat([n, p], dim=1))
prompt_embeds = prompt_embeds_arr
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
pooled_prompt_embeds.dtype,
device,
generator,
latents,
)
if latents.shape[-2] != latents.shape[-1]:
B, C, H, W = latents.shape
pH, pW = H // self.transformer.config.patch_size, W // self.transformer.config.patch_size
img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
img_ids = torch.zeros(pH, pW, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :]
img_ids = img_ids.reshape(pH * pW, -1)
img_ids_pad = torch.zeros(self.transformer.max_seq, 3)
img_ids_pad[:pH*pW, :] = img_ids
img_sizes = img_sizes.unsqueeze(0).to(latents.device)
img_ids = img_ids_pad.unsqueeze(0).to(latents.device)
if self.do_classifier_free_guidance:
img_sizes = img_sizes.repeat(2 * B, 1)
img_ids = img_ids.repeat(2 * B, 1, 1)
else:
img_sizes = img_ids = None
# 5. Prepare timesteps
mu = calculate_shift(self.transformer.max_seq)
scheduler_kwargs = {"mu": mu}
if isinstance(self.scheduler, FlowUniPCMultistepScheduler):
self.scheduler.set_timesteps(num_inference_steps, device=device, shift=math.exp(mu))
timesteps = self.scheduler.timesteps
else:
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
**scheduler_kwargs,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
if latent_model_input.shape[-2] != latent_model_input.shape[-1]:
B, C, H, W = latent_model_input.shape
patch_size = self.transformer.config.patch_size
pH, pW = H // patch_size, W // patch_size
out = torch.zeros(
(B, C, self.transformer.max_seq, patch_size * patch_size),
dtype=latent_model_input.dtype,
device=latent_model_input.device
)
latent_model_input = einops.rearrange(latent_model_input, 'B C (H p1) (W p2) -> B C (H W) (p1 p2)', p1=patch_size, p2=patch_size)
out[:, :, 0:pH*pW] = latent_model_input
latent_model_input = out
noise_pred = self.transformer(
hidden_states = latent_model_input,
timesteps = timestep,
encoder_hidden_states = prompt_embeds,
pooled_embeds = pooled_prompt_embeds,
img_sizes = img_sizes,
img_ids = img_ids,
return_dict = False,
)[0]
noise_pred = -noise_pred
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
image = latents
else:
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return HiDreamImagePipelineOutput(images=image)

View File

@@ -0,0 +1,21 @@
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from diffusers.utils import BaseOutput
@dataclass
class HiDreamImagePipelineOutput(BaseOutput):
"""
Output class for HiDreamImage pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]

View File

@@ -0,0 +1,428 @@
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import BaseOutput, is_scipy_available, logging
from diffusers.utils.torch_utils import randn_tensor
if is_scipy_available():
import scipy.stats
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class FlashFlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
class FlashFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
Euler scheduler.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
shift (`float`, defaults to 1.0):
The shift value for the timestep schedule.
"""
_compatibles = []
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
shift: float = 1.0,
use_dynamic_shifting=False,
base_shift: Optional[float] = 0.5,
max_shift: Optional[float] = 1.15,
base_image_seq_len: Optional[int] = 256,
max_image_seq_len: Optional[int] = 4096,
invert_sigmas: bool = False,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
sigmas = timesteps / num_train_timesteps
if not use_dynamic_shifting:
# when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
self.timesteps = sigmas * num_train_timesteps
self._step_index = None
self._begin_index = None
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item()
@property
def step_index(self):
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def scale_noise(
self,
sample: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
"""
Forward process in flow-matching
Args:
sample (`torch.FloatTensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.FloatTensor`:
A scaled input sample.
"""
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
if sample.device.type == "mps" and torch.is_floating_point(timestep):
# mps does not support float64
schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
timestep = timestep.to(sample.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(sample.device)
timestep = timestep.to(sample.device)
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timestep.shape[0]
else:
# add noise is called before first denoising step to create initial latent(img2img)
step_indices = [self.begin_index] * timestep.shape[0]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(sample.shape):
sigma = sigma.unsqueeze(-1)
sample = sigma * noise + (1.0 - sigma) * sample
return sample
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def set_timesteps(
self,
num_inference_steps: int = None,
device: Union[str, torch.device] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[float] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
if self.config.use_dynamic_shifting and mu is None:
raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
if sigmas is None:
timesteps = np.linspace(
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
)
sigmas = timesteps / self.config.num_train_timesteps
else:
sigmas = np.array(sigmas).astype(np.float32)
num_inference_steps = len(sigmas)
self.num_inference_steps = num_inference_steps
if self.config.use_dynamic_shifting:
sigmas = self.time_shift(mu, 1.0, sigmas)
else:
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
timesteps = sigmas * self.config.num_train_timesteps
if self.config.invert_sigmas:
sigmas = 1.0 - sigmas
timesteps = sigmas * self.config.num_train_timesteps
sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])
else:
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
self.timesteps = timesteps.to(device=device)
self.sigmas = sigmas
self._step_index = None
self._begin_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
s_churn: float = 0.0,
s_tmin: float = 0.0,
s_tmax: float = float("inf"),
s_noise: float = 1.0,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[FlashFlowMatchEulerDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
s_churn (`float`):
s_tmin (`float`):
s_tmax (`float`):
s_noise (`float`, defaults to 1.0):
Scaling factor for noise added to the sample.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
tuple.
Returns:
[`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
returned, otherwise a tuple is returned where the first element is the sample tensor.
"""
if (
isinstance(timestep, int)
or isinstance(timestep, torch.IntTensor)
or isinstance(timestep, torch.LongTensor)
):
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."
),
)
if self.step_index is None:
self._init_step_index(timestep)
# Upcast to avoid precision issues when computing prev_sample
sigma = self.sigmas[self.step_index]
# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)
denoised = sample - model_output * sigma
if self.step_index < self.num_inference_steps - 1:
sigma_next = self.sigmas[self.step_index + 1]
noise = randn_tensor(
model_output.shape,
generator=generator,
device=model_output.device,
dtype=denoised.dtype,
)
sample = sigma_next * noise + (1.0 - sigma_next) * denoised
self._step_index += 1
sample = sample.to(model_output.dtype)
if not return_dict:
return (sample,)
return FlashFlowMatchEulerDiscreteSchedulerOutput(prev_sample=sample)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
"""Constructs an exponential noise schedule."""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = np.array(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
scipy.stats.beta.ppf(timestep, alpha, beta)
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
]
]
)
return sigmas
def __len__(self):
return self.config.num_train_timesteps

View File

@@ -0,0 +1,800 @@
# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
# Convert unipc for flow matching
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
SchedulerMixin,
SchedulerOutput)
from diffusers.utils import deprecate, is_scipy_available
if is_scipy_available():
import scipy.stats
class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
`UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
solver_order (`int`, default `2`):
The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
unconditional sampling.
prediction_type (`str`, defaults to "flow_prediction"):
Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
the flow of the diffusion process.
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
dynamic_thresholding_ratio (`float`, defaults to 0.995):
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
sample_max_value (`float`, defaults to 1.0):
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
predict_x0 (`bool`, defaults to `True`):
Whether to use the updating algorithm on the predicted x0.
solver_type (`str`, default `bh2`):
Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
otherwise.
lower_order_final (`bool`, default `True`):
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
disable_corrector (`list`, default `[]`):
Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`
and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is
usually disabled during the first few steps.
solver_p (`SchedulerMixin`, default `None`):
Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0):
An offset added to the inference steps, as required by some model families.
final_sigmas_type (`str`, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
solver_order: int = 2,
prediction_type: str = "flow_prediction",
shift: Optional[float] = 1.0,
use_dynamic_shifting=False,
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
predict_x0: bool = True,
solver_type: str = "bh2",
lower_order_final: bool = True,
disable_corrector: List[int] = [],
solver_p: SchedulerMixin = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
):
if solver_type not in ["bh1", "bh2"]:
if solver_type in ["midpoint", "heun", "logrho"]:
self.register_to_config(solver_type="bh2")
else:
raise NotImplementedError(
f"{solver_type} is not implemented for {self.__class__}")
self.predict_x0 = predict_x0
# setable values
self.num_inference_steps = None
alphas = np.linspace(1, 1 / num_train_timesteps,
num_train_timesteps)[::-1].copy()
sigmas = 1.0 - alphas
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
if not use_dynamic_shifting:
# when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
sigmas = shift * sigmas / (1 +
(shift - 1) * sigmas) # pyright: ignore
self.sigmas = sigmas
self.timesteps = sigmas * num_train_timesteps
self.model_outputs = [None] * solver_order
self.timestep_list = [None] * solver_order
self.lower_order_nums = 0
self.disable_corrector = disable_corrector
self.solver_p = solver_p
self.last_sample = None
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to(
"cpu") # to avoid too much CPU/GPU communication
self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item()
@property
def step_index(self):
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
# Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
def set_timesteps(
self,
num_inference_steps: Union[int, None] = None,
device: Union[str, torch.device] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[Union[float, None]] = None,
shift: Optional[Union[float, None]] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
Total number of the spacing of the time steps.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
if self.config.use_dynamic_shifting and mu is None:
raise ValueError(
" you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
)
if sigmas is None:
sigmas = np.linspace(self.sigma_max, self.sigma_min,
num_inference_steps +
1).copy()[:-1] # pyright: ignore
if self.config.use_dynamic_shifting:
sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore
else:
if shift is None:
shift = self.config.shift
sigmas = shift * sigmas / (1 +
(shift - 1) * sigmas) # pyright: ignore
if self.config.final_sigmas_type == "sigma_min":
sigma_last = ((1 - self.alphas_cumprod[0]) /
self.alphas_cumprod[0])**0.5
elif self.config.final_sigmas_type == "zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
)
timesteps = sigmas * self.config.num_train_timesteps
sigmas = np.concatenate([sigmas, [sigma_last]
]).astype(np.float32) # pyright: ignore
self.sigmas = torch.from_numpy(sigmas)
self.timesteps = torch.from_numpy(timesteps).to(
device=device, dtype=torch.int64)
self.num_inference_steps = len(timesteps)
self.model_outputs = [
None,
] * self.config.solver_order
self.lower_order_nums = 0
self.last_sample = None
if self.solver_p:
self.solver_p.set_timesteps(self.num_inference_steps, device=device)
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to(
"cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float(
) # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
s = torch.quantile(
abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
s = s.unsqueeze(
1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(
sample, -s, s
) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
sample = sample.reshape(batch_size, channels, *remaining_dims)
sample = sample.to(dtype)
return sample
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
def _sigma_to_alpha_sigma_t(self, sigma):
return 1 - sigma, sigma
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)
def convert_model_output(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
r"""
Convert the model output to the corresponding type the UniPC algorithm needs.
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.Tensor`:
The converted model output.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
if sample is None:
if len(args) > 1:
sample = args[1]
else:
raise ValueError(
"missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
if self.predict_x0:
if self.config.prediction_type == "flow_prediction":
sigma_t = self.sigmas[self.step_index]
x0_pred = sample - sigma_t * model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
" `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
)
if self.config.thresholding:
x0_pred = self._threshold_sample(x0_pred)
return x0_pred
else:
if self.config.prediction_type == "flow_prediction":
sigma_t = self.sigmas[self.step_index]
epsilon = sample - (1 - sigma_t) * model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
" `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
)
if self.config.thresholding:
sigma_t = self.sigmas[self.step_index]
x0_pred = sample - sigma_t * model_output
x0_pred = self._threshold_sample(x0_pred)
epsilon = model_output + x0_pred
return epsilon
def multistep_uni_p_bh_update(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
order: int = None, # pyright: ignore
**kwargs,
) -> torch.Tensor:
"""
One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model at the current timestep.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
order (`int`):
The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
Returns:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
prev_timestep = args[0] if len(args) > 0 else kwargs.pop(
"prev_timestep", None)
if sample is None:
if len(args) > 1:
sample = args[1]
else:
raise ValueError(
" missing `sample` as a required keyward argument")
if order is None:
if len(args) > 2:
order = args[2]
else:
raise ValueError(
" missing `order` as a required keyward argument")
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
model_output_list = self.model_outputs
s0 = self.timestep_list[-1]
m0 = model_output_list[-1]
x = sample
if self.solver_p:
x_t = self.solver_p.step(model_output, s0, x).prev_sample
return x_t
sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[
self.step_index] # pyright: ignore
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
h = lambda_t - lambda_s0
device = sample.device
rks = []
D1s = []
for i in range(1, order):
si = self.step_index - i # pyright: ignore
mi = model_output_list[-(i + 1)]
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
rk = (lambda_si - lambda_s0) / h
rks.append(rk)
D1s.append((mi - m0) / rk) # pyright: ignore
rks.append(1.0)
rks = torch.tensor(rks, device=device)
R = []
b = []
hh = -h if self.predict_x0 else h
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
h_phi_k = h_phi_1 / hh - 1
factorial_i = 1
if self.config.solver_type == "bh1":
B_h = hh
elif self.config.solver_type == "bh2":
B_h = torch.expm1(hh)
else:
raise NotImplementedError()
for i in range(1, order + 1):
R.append(torch.pow(rks, i - 1))
b.append(h_phi_k * factorial_i / B_h)
factorial_i *= i + 1
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = torch.stack(R)
b = torch.tensor(b, device=device)
if len(D1s) > 0:
D1s = torch.stack(D1s, dim=1) # (B, K)
# for order 2, we use a simplified version
if order == 2:
rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
else:
rhos_p = torch.linalg.solve(R[:-1, :-1],
b[:-1]).to(device).to(x.dtype)
else:
D1s = None
if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
if D1s is not None:
pred_res = torch.einsum("k,bkc...->bc...", rhos_p,
D1s) # pyright: ignore
else:
pred_res = 0
x_t = x_t_ - alpha_t * B_h * pred_res
else:
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
if D1s is not None:
pred_res = torch.einsum("k,bkc...->bc...", rhos_p,
D1s) # pyright: ignore
else:
pred_res = 0
x_t = x_t_ - sigma_t * B_h * pred_res
x_t = x_t.to(x.dtype)
return x_t
def multistep_uni_c_bh_update(
self,
this_model_output: torch.Tensor,
*args,
last_sample: torch.Tensor = None,
this_sample: torch.Tensor = None,
order: int = None, # pyright: ignore
**kwargs,
) -> torch.Tensor:
"""
One step for the UniC (B(h) version).
Args:
this_model_output (`torch.Tensor`):
The model outputs at `x_t`.
this_timestep (`int`):
The current timestep `t`.
last_sample (`torch.Tensor`):
The generated sample before the last predictor `x_{t-1}`.
this_sample (`torch.Tensor`):
The generated sample after the last predictor `x_{t}`.
order (`int`):
The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
Returns:
`torch.Tensor`:
The corrected sample tensor at the current timestep.
"""
this_timestep = args[0] if len(args) > 0 else kwargs.pop(
"this_timestep", None)
if last_sample is None:
if len(args) > 1:
last_sample = args[1]
else:
raise ValueError(
" missing`last_sample` as a required keyward argument")
if this_sample is None:
if len(args) > 2:
this_sample = args[2]
else:
raise ValueError(
" missing`this_sample` as a required keyward argument")
if order is None:
if len(args) > 3:
order = args[3]
else:
raise ValueError(
" missing`order` as a required keyward argument")
if this_timestep is not None:
deprecate(
"this_timestep",
"1.0.0",
"Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
model_output_list = self.model_outputs
m0 = model_output_list[-1]
x = last_sample
x_t = this_sample
model_t = this_model_output
sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[
self.step_index - 1] # pyright: ignore
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
h = lambda_t - lambda_s0
device = this_sample.device
rks = []
D1s = []
for i in range(1, order):
si = self.step_index - (i + 1) # pyright: ignore
mi = model_output_list[-(i + 1)]
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
rk = (lambda_si - lambda_s0) / h
rks.append(rk)
D1s.append((mi - m0) / rk) # pyright: ignore
rks.append(1.0)
rks = torch.tensor(rks, device=device)
R = []
b = []
hh = -h if self.predict_x0 else h
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
h_phi_k = h_phi_1 / hh - 1
factorial_i = 1
if self.config.solver_type == "bh1":
B_h = hh
elif self.config.solver_type == "bh2":
B_h = torch.expm1(hh)
else:
raise NotImplementedError()
for i in range(1, order + 1):
R.append(torch.pow(rks, i - 1))
b.append(h_phi_k * factorial_i / B_h)
factorial_i *= i + 1
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = torch.stack(R)
b = torch.tensor(b, device=device)
if len(D1s) > 0:
D1s = torch.stack(D1s, dim=1)
else:
D1s = None
# for order 1, we use a simplified version
if order == 1:
rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
else:
rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
if D1s is not None:
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = model_t - m0
x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
else:
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
if D1s is not None:
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = model_t - m0
x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
x_t = x_t.to(x.dtype)
return x_t
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
"""
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step(self,
model_output: torch.Tensor,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
generator=None) -> Union[SchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
the multistep UniPC.
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if self.step_index is None:
self._init_step_index(timestep)
use_corrector = (
self.step_index > 0 and
self.step_index - 1 not in self.disable_corrector and
self.last_sample is not None # pyright: ignore
)
model_output_convert = self.convert_model_output(
model_output, sample=sample)
if use_corrector:
sample = self.multistep_uni_c_bh_update(
this_model_output=model_output_convert,
last_sample=self.last_sample,
this_sample=sample,
order=self.this_order,
)
for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.timestep_list[i] = self.timestep_list[i + 1]
self.model_outputs[-1] = model_output_convert
self.timestep_list[-1] = timestep # pyright: ignore
if self.config.lower_order_final:
this_order = min(self.config.solver_order,
len(self.timesteps) -
self.step_index) # pyright: ignore
else:
this_order = self.config.solver_order
self.this_order = min(this_order,
self.lower_order_nums + 1) # warmup for multistep
assert self.this_order > 0
self.last_sample = sample
prev_sample = self.multistep_uni_p_bh_update(
model_output=model_output, # pass the original non-converted model output, in case solver-p is used
sample=sample,
order=self.this_order,
)
if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1
# upon completion increase step index by one
self._step_index += 1 # pyright: ignore
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def scale_model_input(self, sample: torch.Tensor, *args,
**kwargs) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.Tensor`):
The input sample.
Returns:
`torch.Tensor`:
A scaled input sample.
"""
return sample
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(
device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(
timesteps):
# mps does not support float64
schedule_timesteps = self.timesteps.to(
original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(
original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [
self.index_for_timestep(t, schedule_timesteps)
for t in timesteps
]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# add noise is called before first denoising step to create initial latent(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps

View File

@@ -649,11 +649,13 @@ def targeted_flow_guidance(
noise,
timesteps
).detach()
unconditional_noisy_latents = sd.condition_noisy_latents(unconditional_noisy_latents, batch)
conditional_noisy_latents = sd.add_noise(
conditional_latents,
noise,
timesteps
).detach()
conditional_noisy_latents = sd.condition_noisy_latents(conditional_noisy_latents, batch)
# disable the lora to get a baseline prediction
sd.network.is_active = False

View File

@@ -725,9 +725,13 @@ class BaseModel:
do_classifier_free_guidance = True
# check if batch size of embeddings matches batch size of latents
if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
if isinstance(text_embeddings.text_embeds, list):
te_batch_size = text_embeddings.text_embeds[0].shape[0]
else:
te_batch_size = text_embeddings.text_embeds.shape[0]
if latents.shape[0] == te_batch_size:
do_classifier_free_guidance = False
elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
elif latents.shape[0] * 2 != te_batch_size:
raise ValueError(
"Batch size of latents must be the same or half the batch size of text embeddings")
latents = latents.to(self.device_torch)

View File

@@ -36,7 +36,10 @@ class PromptEmbeds:
self.attention_mask = attention_mask
def to(self, *args, **kwargs):
self.text_embeds = self.text_embeds.to(*args, **kwargs)
if isinstance(self.text_embeds, list) or isinstance(self.text_embeds, tuple):
self.text_embeds = [t.to(*args, **kwargs) for t in self.text_embeds]
else:
self.text_embeds = self.text_embeds.to(*args, **kwargs)
if self.pooled_embeds is not None:
self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)
if self.attention_mask is not None:
@@ -45,7 +48,10 @@ class PromptEmbeds:
def detach(self):
new_embeds = self.clone()
new_embeds.text_embeds = new_embeds.text_embeds.detach()
if isinstance(new_embeds.text_embeds, list) or isinstance(new_embeds.text_embeds, tuple):
new_embeds.text_embeds = [t.detach() for t in new_embeds.text_embeds]
else:
new_embeds.text_embeds = new_embeds.text_embeds.detach()
if new_embeds.pooled_embeds is not None:
new_embeds.pooled_embeds = new_embeds.pooled_embeds.detach()
if new_embeds.attention_mask is not None:
@@ -53,10 +59,14 @@ class PromptEmbeds:
return new_embeds
def clone(self):
if self.pooled_embeds is not None:
prompt_embeds = PromptEmbeds([self.text_embeds.clone(), self.pooled_embeds.clone()])
if isinstance(self.text_embeds, list) or isinstance(self.text_embeds, tuple):
cloned_text_embeds = [t.clone() for t in self.text_embeds]
else:
prompt_embeds = PromptEmbeds(self.text_embeds.clone())
cloned_text_embeds = self.text_embeds.clone()
if self.pooled_embeds is not None:
prompt_embeds = PromptEmbeds([cloned_text_embeds, self.pooled_embeds.clone()])
else:
prompt_embeds = PromptEmbeds(cloned_text_embeds)
if self.attention_mask is not None:
prompt_embeds.attention_mask = self.attention_mask.clone()
@@ -64,12 +74,18 @@ class PromptEmbeds:
def expand_to_batch(self, batch_size):
pe = self.clone()
current_batch_size = pe.text_embeds.shape[0]
if isinstance(pe.text_embeds, list) or isinstance(pe.text_embeds, tuple):
current_batch_size = pe.text_embeds[0].shape[0]
else:
current_batch_size = pe.text_embeds.shape[0]
if current_batch_size == batch_size:
return pe
if current_batch_size != 1:
raise Exception("Can only expand batch size for batch size 1")
pe.text_embeds = pe.text_embeds.expand(batch_size, -1)
if isinstance(pe.text_embeds, list) or isinstance(pe.text_embeds, tuple):
pe.text_embeds = [t.expand(batch_size, -1) for t in pe.text_embeds]
else:
pe.text_embeds = pe.text_embeds.expand(batch_size, -1)
if pe.pooled_embeds is not None:
pe.pooled_embeds = pe.pooled_embeds.expand(batch_size, -1)
if pe.attention_mask is not None:
@@ -145,7 +161,13 @@ class EncodedPromptPair:
def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]):
text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0)
if isinstance(prompt_embeds[0].text_embeds, list) or isinstance(prompt_embeds[0].text_embeds, tuple):
embed_list = []
for i in range(len(prompt_embeds[0].text_embeds)):
embed_list.append(torch.cat([p.text_embeds[i] for p in prompt_embeds], dim=0))
text_embeds = embed_list
else:
text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0)
pooled_embeds = None
if prompt_embeds[0].pooled_embeds is not None:
pooled_embeds = torch.cat([p.pooled_embeds for p in prompt_embeds], dim=0)
@@ -196,7 +218,16 @@ def split_prompt_embeds(concatenated: PromptEmbeds, num_parts=None) -> List[Prom
if num_parts is None:
# use batch size
num_parts = concatenated.text_embeds.shape[0]
text_embeds_splits = torch.chunk(concatenated.text_embeds, num_parts, dim=0)
if isinstance(concatenated.text_embeds, list) or isinstance(concatenated.text_embeds, tuple):
# split each part
text_embeds_splits = [
torch.chunk(text, num_parts, dim=0)
for text in concatenated.text_embeds
]
text_embeds_splits = list(zip(*text_embeds_splits))
else:
text_embeds_splits = torch.chunk(concatenated.text_embeds, num_parts, dim=0)
if concatenated.pooled_embeds is not None:
pooled_embeds_splits = torch.chunk(concatenated.pooled_embeds, num_parts, dim=0)

View File

@@ -283,7 +283,7 @@ export default function SimpleJob({
options={[
{ value: 'sigmoid', label: 'Sigmoid' },
{ value: 'linear', label: 'Linear' },
{ value: 'flux_shift', label: 'Flux Shift' },
{ value: 'shift', label: 'Shift' },
]}
/>
<SelectInput

View File

@@ -16,7 +16,7 @@ export const defaultDatasetConfig: DatasetConfig = {
export const defaultJobConfig: JobConfig = {
job: 'extension',
config: {
name: 'my_first_flex_lora_v1',
name: 'my_first_lora_v1',
process: [
{
type: 'ui_trainer',
@@ -31,6 +31,9 @@ export const defaultJobConfig: JobConfig = {
linear_alpha: 32,
lokr_full_rank: true,
lokr_factor: -1,
network_kwargs: {
ignore_if_contains: [],
},
},
save: {
dtype: 'bf16',
@@ -43,7 +46,7 @@ export const defaultJobConfig: JobConfig = {
train: {
batch_size: 1,
bypass_guidance_embedding: true,
steps: 2000,
steps: 3000,
gradient_accumulation: 1,
train_unet: true,
train_text_encoder: false,
@@ -58,7 +61,7 @@ export const defaultJobConfig: JobConfig = {
unload_text_encoder: false,
lr: 0.0001,
ema_config: {
use_ema: true,
use_ema: false,
ema_decay: 0.99,
},
dtype: 'bf16',

View File

@@ -12,6 +12,7 @@ export const modelArchs = [
{ name: 'flux', label: 'Flux.1' },
{ name: 'wan21', label: 'Wan 2.1' },
{ name: 'lumina2', label: 'Lumina2' },
{ name: 'hidream', label: 'HiDream' },
];
export const isVideoModelFromArch = (arch: string) => {
@@ -83,6 +84,20 @@ export const options = {
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
},
},
{
name_or_path: 'HiDream-ai/HiDream-I1-Full',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].model.arch': ['hidream', defaultModelArch],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].train.lr': [0.0002, 0.0001],
'config.process[0].train.timestep_type': ['shift', 'sigmoid'],
'config.process[0].network.network_kwargs.ignore_if_contains': [['ff_i.experts', 'ff_i.gate'], []],
},
},
{
name_or_path: 'ostris/objective-reality',
dev_only: true,

View File

@@ -55,6 +55,9 @@ export interface NetworkConfig {
linear_alpha: number;
lokr_full_rank: boolean;
lokr_factor: number;
network_kwargs: {
ignore_if_contains: string[];
}
}
export interface SaveConfig {

View File

@@ -1 +1 @@
VERSION = "0.2.5"
VERSION = "0.2.6"