mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Add initial support for chroma radiance
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from .chroma import ChromaModel
|
||||
from .chroma import ChromaModel, ChromaRadianceModel
|
||||
from .hidream import HidreamModel, HidreamE1Model
|
||||
from .f_light import FLiteModel
|
||||
from .omnigen2 import OmniGen2Model
|
||||
@@ -9,6 +9,7 @@ from .qwen_image import QwenImageModel, QwenImageEditModel
|
||||
AI_TOOLKIT_MODELS = [
|
||||
# put a list of models here
|
||||
ChromaModel,
|
||||
ChromaRadianceModel,
|
||||
HidreamModel,
|
||||
HidreamE1Model,
|
||||
FLiteModel,
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
from .chroma_model import ChromaModel
|
||||
from .chroma_model import ChromaModel
|
||||
from .chroma_radiance_model import ChromaRadianceModel
|
||||
@@ -15,7 +15,7 @@ from toolkit.accelerator import unwrap_model
|
||||
from optimum.quanto import freeze, QTensor
|
||||
from toolkit.util.quantize import quantize, get_qtype
|
||||
from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer
|
||||
from .pipeline import ChromaPipeline
|
||||
from .pipeline import ChromaPipeline, prepare_latent_image_ids
|
||||
from einops import rearrange, repeat
|
||||
import random
|
||||
import torch.nn.functional as F
|
||||
@@ -324,12 +324,19 @@ class ChromaModel(BaseModel):
|
||||
ph=2,
|
||||
pw=2
|
||||
)
|
||||
|
||||
img_ids = prepare_latent_image_ids(
|
||||
bs,
|
||||
h,
|
||||
w,
|
||||
patch_size=2
|
||||
).to(device=self.device_torch)
|
||||
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c",
|
||||
b=bs).to(self.device_torch)
|
||||
# img_ids = torch.zeros(h // 2, w // 2, 3)
|
||||
# img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
||||
# img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
||||
# img_ids = repeat(img_ids, "h w c -> b (h w) c",
|
||||
# b=bs).to(self.device_torch)
|
||||
|
||||
txt_ids = torch.zeros(
|
||||
bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)
|
||||
|
||||
@@ -0,0 +1,445 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from toolkit.config_modules import GenerateImageConfig, ModelConfig
|
||||
from PIL import Image
|
||||
from toolkit.models.base_model import BaseModel
|
||||
from toolkit.basic import flush
|
||||
from diffusers import AutoencoderKL
|
||||
# from toolkit.pixel_shuffle_encoder import AutoencoderPixelMixer
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
|
||||
from toolkit.dequantize import patch_dequantization_on_save
|
||||
from toolkit.accelerator import unwrap_model
|
||||
from optimum.quanto import freeze, QTensor
|
||||
from toolkit.util.quantize import quantize, get_qtype
|
||||
from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer
|
||||
from .pipeline import ChromaPipeline, prepare_latent_image_ids
|
||||
from einops import rearrange, repeat
|
||||
import random
|
||||
import torch.nn.functional as F
|
||||
from .src.radiance import Chroma, chroma_params
|
||||
from safetensors.torch import load_file, save_file
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.models.FakeVAE import FakeVAE
|
||||
import huggingface_hub
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
|
||||
scheduler_config = {
|
||||
"base_image_seq_len": 256,
|
||||
"base_shift": 0.5,
|
||||
"max_image_seq_len": 4096,
|
||||
"max_shift": 1.15,
|
||||
"num_train_timesteps": 1000,
|
||||
"shift": 3.0,
|
||||
"use_dynamic_shifting": True
|
||||
}
|
||||
|
||||
class FakeConfig:
|
||||
# for diffusers compatability
|
||||
def __init__(self):
|
||||
self.attention_head_dim = 128
|
||||
self.guidance_embeds = True
|
||||
self.in_channels = 64
|
||||
self.joint_attention_dim = 4096
|
||||
self.num_attention_heads = 24
|
||||
self.num_layers = 19
|
||||
self.num_single_layers = 38
|
||||
self.patch_size = 1
|
||||
|
||||
class FakeCLIP(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.dtype = torch.bfloat16
|
||||
self.device = 'cuda'
|
||||
self.text_model = None
|
||||
self.tokenizer = None
|
||||
self.model_max_length = 77
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return torch.zeros(1, 1, 1).to(self.device)
|
||||
|
||||
|
||||
class ChromaRadianceModel(BaseModel):
|
||||
arch = "chroma_radiance"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device,
|
||||
model_config: ModelConfig,
|
||||
dtype='bf16',
|
||||
custom_pipeline=None,
|
||||
noise_scheduler=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
device,
|
||||
model_config,
|
||||
dtype,
|
||||
custom_pipeline,
|
||||
noise_scheduler,
|
||||
**kwargs
|
||||
)
|
||||
self.is_flow_matching = True
|
||||
self.is_transformer = True
|
||||
self.target_lora_modules = ['Chroma']
|
||||
|
||||
# static method to get the noise scheduler
|
||||
@staticmethod
|
||||
def get_train_scheduler():
|
||||
return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
|
||||
|
||||
def get_bucket_divisibility(self):
|
||||
# return the bucket divisibility for the model
|
||||
return 32
|
||||
|
||||
def load_model(self):
|
||||
dtype = self.torch_dtype
|
||||
|
||||
# will be updated if we detect a existing checkpoint in training folder
|
||||
model_path = self.model_config.name_or_path
|
||||
|
||||
if model_path == "lodestones/Chroma":
|
||||
print("Looking for latest Chroma checkpoint")
|
||||
# get the latest checkpoint
|
||||
files_list = huggingface_hub.list_repo_files(model_path)
|
||||
print(files_list)
|
||||
latest_version = 28 # current latest version at time of writing
|
||||
while True:
|
||||
if f"chroma-unlocked-v{latest_version}.safetensors" not in files_list:
|
||||
latest_version -= 1
|
||||
break
|
||||
else:
|
||||
latest_version += 1
|
||||
print(f"Using latest Chroma version: v{latest_version}")
|
||||
|
||||
# make sure we have it
|
||||
model_path = huggingface_hub.hf_hub_download(
|
||||
repo_id=model_path,
|
||||
filename=f"chroma-unlocked-v{latest_version}.safetensors",
|
||||
)
|
||||
elif model_path.startswith("lodestones/Chroma/v"):
|
||||
# get the version number
|
||||
version = model_path.split("/")[-1].split("v")[-1]
|
||||
print(f"Using Chroma version: v{version}")
|
||||
# make sure we have it
|
||||
model_path = huggingface_hub.hf_hub_download(
|
||||
repo_id='lodestones/Chroma',
|
||||
filename=f"chroma-unlocked-v{version}.safetensors",
|
||||
)
|
||||
elif model_path.startswith("lodestones/Chroma1-"):
|
||||
# will have a file in the repo that is Chroma1-whatever.safetensors
|
||||
model_path = huggingface_hub.hf_hub_download(
|
||||
repo_id=model_path,
|
||||
filename=f"{model_path.split('/')[-1]}.safetensors",
|
||||
)
|
||||
|
||||
else:
|
||||
# check if the model path is a local file
|
||||
if os.path.exists(model_path):
|
||||
print(f"Using local model: {model_path}")
|
||||
else:
|
||||
raise ValueError(f"Model path {model_path} does not exist")
|
||||
|
||||
# extras_path = 'black-forest-labs/FLUX.1-schnell'
|
||||
# schnell model is gated now, use flex instead
|
||||
extras_path = 'ostris/Flex.1-alpha'
|
||||
|
||||
self.print_and_status_update("Loading transformer")
|
||||
|
||||
if model_path.endswith('.pth') or model_path.endswith('.pt'):
|
||||
chroma_state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
|
||||
else:
|
||||
chroma_state_dict = load_file(model_path, 'cpu')
|
||||
|
||||
# determine number of double and single blocks
|
||||
double_blocks = 0
|
||||
single_blocks = 0
|
||||
for key in chroma_state_dict.keys():
|
||||
if "double_blocks" in key:
|
||||
block_num = int(key.split(".")[1]) + 1
|
||||
if block_num > double_blocks:
|
||||
double_blocks = block_num
|
||||
elif "single_blocks" in key:
|
||||
block_num = int(key.split(".")[1]) + 1
|
||||
if block_num > single_blocks:
|
||||
single_blocks = block_num
|
||||
print(f"Double Blocks: {double_blocks}")
|
||||
print(f"Single Blocks: {single_blocks}")
|
||||
|
||||
chroma_params.depth = double_blocks
|
||||
chroma_params.depth_single_blocks = single_blocks
|
||||
transformer = Chroma(chroma_params)
|
||||
|
||||
# add dtype, not sure why it doesnt have it
|
||||
transformer.dtype = dtype
|
||||
# load the state dict into the model
|
||||
transformer.load_state_dict(chroma_state_dict)
|
||||
|
||||
transformer.to(self.quantize_device, dtype=dtype)
|
||||
|
||||
transformer.config = FakeConfig()
|
||||
transformer.config.num_layers = double_blocks
|
||||
transformer.config.num_single_layers = single_blocks
|
||||
|
||||
if self.model_config.quantize:
|
||||
# patch the state dict method
|
||||
patch_dequantization_on_save(transformer)
|
||||
quantization_type = get_qtype(self.model_config.qtype)
|
||||
self.print_and_status_update("Quantizing transformer")
|
||||
quantize(transformer, weights=quantization_type,
|
||||
**self.model_config.quantize_kwargs)
|
||||
freeze(transformer)
|
||||
transformer.to(self.device_torch)
|
||||
else:
|
||||
transformer.to(self.device_torch, dtype=dtype)
|
||||
|
||||
flush()
|
||||
|
||||
self.print_and_status_update("Loading T5")
|
||||
tokenizer_2 = T5TokenizerFast.from_pretrained(
|
||||
extras_path, subfolder="tokenizer_2", torch_dtype=dtype
|
||||
)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained(
|
||||
extras_path, subfolder="text_encoder_2", torch_dtype=dtype
|
||||
)
|
||||
text_encoder_2.to(self.device_torch, dtype=dtype)
|
||||
flush()
|
||||
|
||||
if self.model_config.quantize_te:
|
||||
self.print_and_status_update("Quantizing T5")
|
||||
quantize(text_encoder_2, weights=get_qtype(
|
||||
self.model_config.qtype))
|
||||
freeze(text_encoder_2)
|
||||
flush()
|
||||
|
||||
# self.print_and_status_update("Loading CLIP")
|
||||
text_encoder = FakeCLIP()
|
||||
tokenizer = FakeCLIP()
|
||||
text_encoder.to(self.device_torch, dtype=dtype)
|
||||
|
||||
self.noise_scheduler = ChromaRadianceModel.get_train_scheduler()
|
||||
|
||||
self.print_and_status_update("Loading VAE")
|
||||
# vae = AutoencoderKL.from_pretrained(
|
||||
# extras_path,
|
||||
# subfolder="vae",
|
||||
# torch_dtype=dtype
|
||||
# )
|
||||
vae = FakeVAE()
|
||||
vae = vae.to(self.device_torch, dtype=dtype)
|
||||
|
||||
self.print_and_status_update("Making pipe")
|
||||
|
||||
pipe: ChromaPipeline = ChromaPipeline(
|
||||
scheduler=self.noise_scheduler,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder_2=None,
|
||||
tokenizer_2=tokenizer_2,
|
||||
vae=vae,
|
||||
transformer=None,
|
||||
is_radiance=True,
|
||||
)
|
||||
# for quantization, it works best to do these after making the pipe
|
||||
pipe.text_encoder_2 = text_encoder_2
|
||||
pipe.transformer = transformer
|
||||
|
||||
self.print_and_status_update("Preparing Model")
|
||||
|
||||
text_encoder = [pipe.text_encoder, pipe.text_encoder_2]
|
||||
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
|
||||
|
||||
pipe.transformer = pipe.transformer.to(self.device_torch)
|
||||
|
||||
flush()
|
||||
# just to make sure everything is on the right device and dtype
|
||||
text_encoder[0].to(self.device_torch)
|
||||
text_encoder[0].requires_grad_(False)
|
||||
text_encoder[0].eval()
|
||||
text_encoder[1].to(self.device_torch)
|
||||
text_encoder[1].requires_grad_(False)
|
||||
text_encoder[1].eval()
|
||||
pipe.transformer = pipe.transformer.to(self.device_torch)
|
||||
flush()
|
||||
|
||||
# save it to the model class
|
||||
self.vae = vae
|
||||
self.text_encoder = text_encoder # list of text encoders
|
||||
self.tokenizer = tokenizer # list of tokenizers
|
||||
self.model = pipe.transformer
|
||||
self.pipeline = pipe
|
||||
self.print_and_status_update("Model Loaded")
|
||||
|
||||
def get_generation_pipeline(self):
|
||||
scheduler = ChromaRadianceModel.get_train_scheduler()
|
||||
pipeline = ChromaPipeline(
|
||||
scheduler=scheduler,
|
||||
text_encoder=unwrap_model(self.text_encoder[0]),
|
||||
tokenizer=self.tokenizer[0],
|
||||
text_encoder_2=unwrap_model(self.text_encoder[1]),
|
||||
tokenizer_2=self.tokenizer[1],
|
||||
vae=unwrap_model(self.vae),
|
||||
transformer=unwrap_model(self.transformer),
|
||||
is_radiance=True,
|
||||
)
|
||||
|
||||
# pipeline = pipeline.to(self.device_torch)
|
||||
|
||||
return pipeline
|
||||
|
||||
def generate_single_image(
|
||||
self,
|
||||
pipeline: ChromaPipeline,
|
||||
gen_config: GenerateImageConfig,
|
||||
conditional_embeds: PromptEmbeds,
|
||||
unconditional_embeds: PromptEmbeds,
|
||||
generator: torch.Generator,
|
||||
extra: dict,
|
||||
):
|
||||
|
||||
extra['negative_prompt_embeds'] = unconditional_embeds.text_embeds
|
||||
extra['negative_prompt_attn_mask'] = unconditional_embeds.attention_mask
|
||||
|
||||
img = pipeline(
|
||||
prompt_embeds=conditional_embeds.text_embeds,
|
||||
prompt_attn_mask=conditional_embeds.attention_mask,
|
||||
height=gen_config.height,
|
||||
width=gen_config.width,
|
||||
num_inference_steps=gen_config.num_inference_steps,
|
||||
guidance_scale=gen_config.guidance_scale,
|
||||
latents=gen_config.latents,
|
||||
generator=generator,
|
||||
**extra
|
||||
).images[0]
|
||||
return img
|
||||
|
||||
def get_noise_prediction(
|
||||
self,
|
||||
latent_model_input: torch.Tensor,
|
||||
timestep: torch.Tensor, # 0 to 1000 scale
|
||||
text_embeddings: PromptEmbeds,
|
||||
**kwargs
|
||||
):
|
||||
with torch.no_grad():
|
||||
bs, c, h, w = latent_model_input.shape
|
||||
|
||||
img_ids = prepare_latent_image_ids(
|
||||
bs, h, w, patch_size=16
|
||||
).to(self.device_torch)
|
||||
|
||||
txt_ids = torch.zeros(
|
||||
bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)
|
||||
|
||||
guidance = torch.full([1], 0, device=self.device_torch, dtype=torch.float32)
|
||||
guidance = guidance.expand(bs)
|
||||
|
||||
cast_dtype = self.unet.dtype
|
||||
|
||||
noise_pred = self.unet(
|
||||
img=latent_model_input.to(
|
||||
self.device_torch, cast_dtype
|
||||
),
|
||||
img_ids=img_ids,
|
||||
txt=text_embeddings.text_embeds.to(
|
||||
self.device_torch, cast_dtype
|
||||
),
|
||||
txt_ids=txt_ids,
|
||||
txt_mask=text_embeddings.attention_mask.to(
|
||||
self.device_torch, cast_dtype
|
||||
),
|
||||
timesteps=timestep / 1000,
|
||||
guidance=guidance
|
||||
)
|
||||
|
||||
if isinstance(noise_pred, QTensor):
|
||||
noise_pred = noise_pred.dequantize()
|
||||
|
||||
return noise_pred
|
||||
|
||||
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
|
||||
if isinstance(prompt, str):
|
||||
prompts = [prompt]
|
||||
else:
|
||||
prompts = prompt
|
||||
if self.pipeline.text_encoder.device != self.device_torch:
|
||||
self.pipeline.text_encoder.to(self.device_torch)
|
||||
|
||||
max_length = 512
|
||||
|
||||
device = self.text_encoder[1].device
|
||||
dtype = self.text_encoder[1].dtype
|
||||
|
||||
# T5
|
||||
text_inputs = self.tokenizer[1](
|
||||
prompts,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_length=False,
|
||||
return_overflowing_tokens=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
|
||||
prompt_embeds = self.text_encoder[1](text_input_ids.to(device), output_hidden_states=False)[0]
|
||||
|
||||
dtype = self.text_encoder[1].dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
prompt_attention_mask = text_inputs["attention_mask"]
|
||||
|
||||
pe = PromptEmbeds(
|
||||
prompt_embeds
|
||||
)
|
||||
pe.attention_mask = prompt_attention_mask
|
||||
return pe
|
||||
|
||||
def get_model_has_grad(self):
|
||||
# return from a weight if it has grad
|
||||
return False
|
||||
def get_te_has_grad(self):
|
||||
# return from a weight if it has grad
|
||||
return False
|
||||
|
||||
def save_model(self, output_path, meta, save_dtype):
|
||||
if not output_path.endswith(".safetensors"):
|
||||
output_path = output_path + ".safetensors"
|
||||
# only save the unet
|
||||
transformer: Chroma = unwrap_model(self.model)
|
||||
state_dict = transformer.state_dict()
|
||||
save_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
if isinstance(v, QTensor):
|
||||
v = v.dequantize()
|
||||
save_dict[k] = v.clone().to('cpu', dtype=save_dtype)
|
||||
|
||||
meta = get_meta_for_safetensors(meta, name='chroma')
|
||||
save_file(save_dict, output_path, metadata=meta)
|
||||
|
||||
def get_loss_target(self, *args, **kwargs):
|
||||
noise = kwargs.get('noise')
|
||||
batch = kwargs.get('batch')
|
||||
return (noise - batch.latents).detach()
|
||||
|
||||
def convert_lora_weights_before_save(self, state_dict):
|
||||
# currently starte with transformer. but needs to start with diffusion_model. for comfyui
|
||||
new_sd = {}
|
||||
for key, value in state_dict.items():
|
||||
new_key = key.replace("transformer.", "diffusion_model.")
|
||||
new_sd[new_key] = value
|
||||
return new_sd
|
||||
|
||||
def convert_lora_weights_before_load(self, state_dict):
|
||||
# saved as diffusion_model. but needs to be transformer. for ai-toolkit
|
||||
new_sd = {}
|
||||
for key, value in state_dict.items():
|
||||
new_key = key.replace("diffusion_model.", "transformer.")
|
||||
new_sd[new_key] = value
|
||||
return new_sd
|
||||
|
||||
def get_base_model_version(self):
|
||||
return "chroma_radiance"
|
||||
@@ -6,6 +6,7 @@ from diffusers import FluxPipeline
|
||||
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
|
||||
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
||||
from diffusers.utils import is_torch_xla_available
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -16,7 +17,134 @@ else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
def prepare_latent_image_ids(batch_size, height, width, patch_size=2, max_offset=0):
|
||||
"""
|
||||
Generates positional embeddings for a latent image.
|
||||
|
||||
Args:
|
||||
batch_size (int): The number of images in the batch.
|
||||
height (int): The height of the image.
|
||||
width (int): The width of the image.
|
||||
patch_size (int, optional): The size of the patches. Defaults to 2.
|
||||
max_offset (int, optional): The maximum random offset to apply. Defaults to 0.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: A tensor containing the positional embeddings.
|
||||
"""
|
||||
# the random pos embedding helps generalize to larger res without training at large res
|
||||
# pos embedding for rope, 2d pos embedding, corner embedding and not center based
|
||||
latent_image_ids = torch.zeros(height // patch_size, width // patch_size, 3)
|
||||
|
||||
# Add positional encodings
|
||||
latent_image_ids[..., 1] = (
|
||||
latent_image_ids[..., 1] + torch.arange(height // patch_size)[:, None]
|
||||
)
|
||||
latent_image_ids[..., 2] = (
|
||||
latent_image_ids[..., 2] + torch.arange(width // patch_size)[None, :]
|
||||
)
|
||||
|
||||
# Add random offset if specified
|
||||
if max_offset > 0:
|
||||
offset_y = torch.randint(0, max_offset + 1, (1,)).item()
|
||||
offset_x = torch.randint(0, max_offset + 1, (1,)).item()
|
||||
latent_image_ids[..., 1] += offset_y
|
||||
latent_image_ids[..., 2] += offset_x
|
||||
|
||||
|
||||
(
|
||||
latent_image_id_height,
|
||||
latent_image_id_width,
|
||||
latent_image_id_channels,
|
||||
) = latent_image_ids.shape
|
||||
|
||||
# Reshape for batch
|
||||
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
batch_size,
|
||||
latent_image_id_height * latent_image_id_width,
|
||||
latent_image_id_channels,
|
||||
)
|
||||
|
||||
return latent_image_ids
|
||||
|
||||
|
||||
class ChromaPipeline(FluxPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler,
|
||||
vae,
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
text_encoder_2,
|
||||
tokenizer_2,
|
||||
transformer,
|
||||
image_encoder = None,
|
||||
feature_extractor = None,
|
||||
is_radiance: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
transformer=transformer,
|
||||
image_encoder=image_encoder,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.is_radiance = is_radiance
|
||||
self.vae_scale_factor = 8 if not is_radiance else 1
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is not None:
|
||||
latent_image_ids = prepare_latent_image_ids(
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
patch_size=2 if not self.is_radiance else 16
|
||||
).to(device=device, dtype=dtype)
|
||||
# latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
return latents.to(device=device, dtype=dtype), latent_image_ids
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
if not self.is_radiance:
|
||||
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
|
||||
# latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
latent_image_ids = prepare_latent_image_ids(
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
patch_size=2 if not self.is_radiance else 16
|
||||
).to(device=device, dtype=dtype)
|
||||
|
||||
return latents, latent_image_ids
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
@@ -70,6 +198,8 @@ class ChromaPipeline(FluxPipeline):
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = 64 // 4
|
||||
if self.is_radiance:
|
||||
num_channels_latents = 3
|
||||
latents, latent_image_ids = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
@@ -82,8 +212,8 @@ class ChromaPipeline(FluxPipeline):
|
||||
)
|
||||
|
||||
# extend img ids to match batch size
|
||||
latent_image_ids = latent_image_ids.unsqueeze(0)
|
||||
latent_image_ids = torch.cat([latent_image_ids] * batch_size, dim=0)
|
||||
# latent_image_ids = latent_image_ids.unsqueeze(0)
|
||||
# latent_image_ids = torch.cat([latent_image_ids] * batch_size, dim=0)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
@@ -180,8 +310,9 @@ class ChromaPipeline(FluxPipeline):
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = self._unpack_latents(
|
||||
latents, height, width, self.vae_scale_factor)
|
||||
if not self.is_radiance:
|
||||
latents = self._unpack_latents(
|
||||
latents, height, width, self.vae_scale_factor)
|
||||
latents = (latents / self.vae.config.scaling_factor) + \
|
||||
self.vae.config.shift_factor
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
@@ -7,6 +7,7 @@ from torch import Tensor, nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .math import attention, rope
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
@@ -88,7 +89,7 @@ class RMSNorm(torch.nn.Module):
|
||||
# return self._forward(x)
|
||||
|
||||
|
||||
def distribute_modulations(tensor: torch.Tensor):
|
||||
def distribute_modulations(tensor: torch.Tensor, depth_single_blocks, depth_double_blocks):
|
||||
"""
|
||||
Distributes slices of the tensor into the block_dict as ModulationOut objects.
|
||||
|
||||
@@ -102,25 +103,25 @@ def distribute_modulations(tensor: torch.Tensor):
|
||||
# HARD CODED VALUES! lookup table for the generated vectors
|
||||
# TODO: move this into chroma config!
|
||||
# Add 38 single mod blocks
|
||||
for i in range(38):
|
||||
for i in range(depth_single_blocks):
|
||||
key = f"single_blocks.{i}.modulation.lin"
|
||||
block_dict[key] = None
|
||||
|
||||
# Add 19 image double blocks
|
||||
for i in range(19):
|
||||
for i in range(depth_double_blocks):
|
||||
key = f"double_blocks.{i}.img_mod.lin"
|
||||
block_dict[key] = None
|
||||
|
||||
# Add 19 text double blocks
|
||||
for i in range(19):
|
||||
for i in range(depth_double_blocks):
|
||||
key = f"double_blocks.{i}.txt_mod.lin"
|
||||
block_dict[key] = None
|
||||
|
||||
# Add the final layer
|
||||
block_dict["final_layer.adaLN_modulation.1"] = None
|
||||
# 6.2b version
|
||||
block_dict["lite_double_blocks.4.img_mod.lin"] = None
|
||||
block_dict["lite_double_blocks.4.txt_mod.lin"] = None
|
||||
# block_dict["lite_double_blocks.4.img_mod.lin"] = None
|
||||
# block_dict["lite_double_blocks.4.txt_mod.lin"] = None
|
||||
|
||||
idx = 0 # Index to keep track of the vector slices
|
||||
|
||||
@@ -173,6 +174,219 @@ def distribute_modulations(tensor: torch.Tensor):
|
||||
return block_dict
|
||||
|
||||
|
||||
|
||||
class NerfEmbedder(nn.Module):
|
||||
"""
|
||||
An embedder module that combines input features with a 2D positional
|
||||
encoding that mimics the Discrete Cosine Transform (DCT).
|
||||
|
||||
This module takes an input tensor of shape (B, P^2, C), where P is the
|
||||
patch size, and enriches it with positional information before projecting
|
||||
it to a new hidden size.
|
||||
"""
|
||||
def __init__(self, in_channels, hidden_size_input, max_freqs):
|
||||
"""
|
||||
Initializes the NerfEmbedder.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of channels in the input tensor.
|
||||
hidden_size_input (int): The desired dimension of the output embedding.
|
||||
max_freqs (int): The number of frequency components to use for both
|
||||
the x and y dimensions of the positional encoding.
|
||||
The total number of positional features will be max_freqs^2.
|
||||
"""
|
||||
super().__init__()
|
||||
self.max_freqs = max_freqs
|
||||
self.hidden_size_input = hidden_size_input
|
||||
|
||||
# A linear layer to project the concatenated input features and
|
||||
# positional encodings to the final output dimension.
|
||||
self.embedder = nn.Sequential(
|
||||
nn.Linear(in_channels + max_freqs**2, hidden_size_input)
|
||||
)
|
||||
|
||||
@lru_cache(maxsize=4)
|
||||
def fetch_pos(self, patch_size, device, dtype):
|
||||
"""
|
||||
Generates and caches 2D DCT-like positional embeddings for a given patch size.
|
||||
|
||||
The LRU cache is a performance optimization that avoids recomputing the
|
||||
same positional grid on every forward pass.
|
||||
|
||||
Args:
|
||||
patch_size (int): The side length of the square input patch.
|
||||
device: The torch device to create the tensors on.
|
||||
dtype: The torch dtype for the tensors.
|
||||
|
||||
Returns:
|
||||
A tensor of shape (1, patch_size^2, max_freqs^2) containing the
|
||||
positional embeddings.
|
||||
"""
|
||||
# Create normalized 1D coordinate grids from 0 to 1.
|
||||
pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
|
||||
pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
|
||||
|
||||
# Create a 2D meshgrid of coordinates.
|
||||
pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
|
||||
|
||||
# Reshape positions to be broadcastable with frequencies.
|
||||
# Shape becomes (patch_size^2, 1, 1).
|
||||
pos_x = pos_x.reshape(-1, 1, 1)
|
||||
pos_y = pos_y.reshape(-1, 1, 1)
|
||||
|
||||
# Create a 1D tensor of frequency values from 0 to max_freqs-1.
|
||||
freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device)
|
||||
|
||||
# Reshape frequencies to be broadcastable for creating 2D basis functions.
|
||||
# freqs_x shape: (1, max_freqs, 1)
|
||||
# freqs_y shape: (1, 1, max_freqs)
|
||||
freqs_x = freqs[None, :, None]
|
||||
freqs_y = freqs[None, None, :]
|
||||
|
||||
# A custom weighting coefficient, not part of standard DCT.
|
||||
# This seems to down-weight the contribution of higher-frequency interactions.
|
||||
coeffs = (1 + freqs_x * freqs_y) ** -1
|
||||
|
||||
# Calculate the 1D cosine basis functions for x and y coordinates.
|
||||
# This is the core of the DCT formulation.
|
||||
dct_x = torch.cos(pos_x * freqs_x * torch.pi)
|
||||
dct_y = torch.cos(pos_y * freqs_y * torch.pi)
|
||||
|
||||
# Combine the 1D basis functions to create 2D basis functions by element-wise
|
||||
# multiplication, and apply the custom coefficients. Broadcasting handles the
|
||||
# combination of all (pos_x, freqs_x) with all (pos_y, freqs_y).
|
||||
# The result is flattened into a feature vector for each position.
|
||||
dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2)
|
||||
|
||||
return dct
|
||||
|
||||
def forward(self, inputs):
|
||||
"""
|
||||
Forward pass for the embedder.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): The input tensor of shape (B, P^2, C).
|
||||
|
||||
Returns:
|
||||
Tensor: The output tensor of shape (B, P^2, hidden_size_input).
|
||||
"""
|
||||
# Get the batch size, number of pixels, and number of channels.
|
||||
B, P2, C = inputs.shape
|
||||
# Store the original dtype to cast back to at the end.
|
||||
original_dtype = inputs.dtype
|
||||
# Force all operations within this module to run in fp32.
|
||||
with torch.autocast("cuda", enabled=False):
|
||||
# Infer the patch side length from the number of pixels (P^2).
|
||||
patch_size = int(P2 ** 0.5)
|
||||
|
||||
inputs = inputs.float()
|
||||
# Fetch the pre-computed or cached positional embeddings.
|
||||
dct = self.fetch_pos(patch_size, inputs.device, torch.float32)
|
||||
|
||||
# Repeat the positional embeddings for each item in the batch.
|
||||
dct = dct.repeat(B, 1, 1)
|
||||
|
||||
# Concatenate the original input features with the positional embeddings
|
||||
# along the feature dimension.
|
||||
inputs = torch.cat([inputs, dct], dim=-1)
|
||||
|
||||
# Project the combined tensor to the target hidden size.
|
||||
inputs = self.embedder.float()(inputs)
|
||||
|
||||
return inputs.to(original_dtype)
|
||||
|
||||
|
||||
|
||||
class NerfGLUBlock(nn.Module):
|
||||
"""
|
||||
A NerfBlock using a Gated Linear Unit (GLU) like MLP.
|
||||
"""
|
||||
def __init__(self, hidden_size_s, hidden_size_x, mlp_ratio, use_compiled):
|
||||
super().__init__()
|
||||
# The total number of parameters for the MLP is increased to accommodate
|
||||
# the gate, value, and output projection matrices.
|
||||
# We now need to generate parameters for 3 matrices.
|
||||
total_params = 3 * hidden_size_x**2 * mlp_ratio
|
||||
self.param_generator = nn.Linear(hidden_size_s, total_params)
|
||||
self.norm = RMSNorm(hidden_size_x, use_compiled)
|
||||
self.mlp_ratio = mlp_ratio
|
||||
# nn.init.zeros_(self.param_generator.weight)
|
||||
# nn.init.zeros_(self.param_generator.bias)
|
||||
|
||||
|
||||
def forward(self, x, s):
|
||||
batch_size, num_x, hidden_size_x = x.shape
|
||||
mlp_params = self.param_generator(s)
|
||||
|
||||
# Split the generated parameters into three parts for the gate, value, and output projection.
|
||||
fc1_gate_params, fc1_value_params, fc2_params = mlp_params.chunk(3, dim=-1)
|
||||
|
||||
# Reshape the parameters into matrices for batch matrix multiplication.
|
||||
fc1_gate = fc1_gate_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio)
|
||||
fc1_value = fc1_value_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio)
|
||||
fc2 = fc2_params.view(batch_size, hidden_size_x * self.mlp_ratio, hidden_size_x)
|
||||
|
||||
# Normalize the generated weight matrices as in the original implementation.
|
||||
fc1_gate = torch.nn.functional.normalize(fc1_gate, dim=-2)
|
||||
fc1_value = torch.nn.functional.normalize(fc1_value, dim=-2)
|
||||
fc2 = torch.nn.functional.normalize(fc2, dim=-2)
|
||||
|
||||
res_x = x
|
||||
x = self.norm(x)
|
||||
|
||||
# Apply the final output projection.
|
||||
x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2)
|
||||
|
||||
x = x + res_x
|
||||
return x
|
||||
|
||||
|
||||
class NerfFinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels, use_compiled):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(hidden_size, use_compiled=use_compiled)
|
||||
self.linear = nn.Linear(hidden_size, out_channels)
|
||||
nn.init.zeros_(self.linear.weight)
|
||||
nn.init.zeros_(self.linear.bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class NerfFinalLayerConv(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels, use_compiled):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(hidden_size, use_compiled=use_compiled)
|
||||
|
||||
# replace nn.Linear with nn.Conv2d since linear is just pointwise conv
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels=hidden_size,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
padding=1
|
||||
)
|
||||
nn.init.zeros_(self.conv.weight)
|
||||
nn.init.zeros_(self.conv.bias)
|
||||
|
||||
def forward(self, x):
|
||||
# shape: [N, C, H, W] !
|
||||
# RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
|
||||
# So, we permute the dimensions to make the channel dimension the last one.
|
||||
x_permuted = x.permute(0, 2, 3, 1) # Shape becomes [N, H, W, C]
|
||||
|
||||
# Apply normalization on the feature/channel dimension
|
||||
x_norm = self.norm(x_permuted)
|
||||
|
||||
# Permute back to the original dimension order for the convolution
|
||||
x_norm_permuted = x_norm.permute(0, 3, 1, 2) # Shape becomes [N, C, H, W]
|
||||
|
||||
# Apply the 3x3 convolution
|
||||
x = self.conv(x_norm_permuted)
|
||||
return x
|
||||
|
||||
|
||||
class Approximator(nn.Module):
|
||||
def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers=4):
|
||||
super().__init__()
|
||||
|
||||
@@ -156,13 +156,19 @@ class Chroma(nn.Module):
|
||||
)
|
||||
|
||||
# TODO: move this hardcoded value to config
|
||||
self.mod_index_length = 344
|
||||
# single layer has 3 modulation vectors
|
||||
# double layer has 6 modulation vectors for each expert
|
||||
# final layer has 2 modulation vectors
|
||||
self.mod_index_length = 3 * params.depth_single_blocks + 2 * 6 * params.depth + 2
|
||||
self.depth_single_blocks = params.depth_single_blocks
|
||||
self.depth_double_blocks = params.depth
|
||||
# self.mod_index = torch.tensor(list(range(self.mod_index_length)), device=0)
|
||||
self.register_buffer(
|
||||
"mod_index",
|
||||
torch.tensor(list(range(self.mod_index_length)), device="cpu"),
|
||||
persistent=False,
|
||||
)
|
||||
self.approximator_in_dim = params.approximator_in_dim
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
@@ -213,7 +219,7 @@ class Chroma(nn.Module):
|
||||
# then and only then we could concatenate it together
|
||||
input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1)
|
||||
mod_vectors = self.distilled_guidance_layer(input_vec.requires_grad_(True))
|
||||
mod_vectors_dict = distribute_modulations(mod_vectors)
|
||||
mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
380
extensions_built_in/diffusion_models/chroma/src/radiance.py
Normal file
380
extensions_built_in/diffusion_models/chroma/src/radiance.py
Normal file
@@ -0,0 +1,380 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
import torch.utils.checkpoint as ckpt
|
||||
|
||||
from .layers import (
|
||||
DoubleStreamBlock,
|
||||
EmbedND,
|
||||
LastLayer,
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
Approximator,
|
||||
distribute_modulations,
|
||||
NerfEmbedder,
|
||||
NerfFinalLayer,
|
||||
NerfFinalLayerConv,
|
||||
NerfGLUBlock
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChromaParams:
|
||||
in_channels: int
|
||||
context_in_dim: int
|
||||
hidden_size: int
|
||||
mlp_ratio: float
|
||||
num_heads: int
|
||||
depth: int
|
||||
depth_single_blocks: int
|
||||
axes_dim: list[int]
|
||||
theta: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
approximator_in_dim: int
|
||||
approximator_depth: int
|
||||
approximator_hidden_size: int
|
||||
patch_size: int
|
||||
nerf_hidden_size: int
|
||||
nerf_mlp_ratio: int
|
||||
nerf_depth: int
|
||||
nerf_max_freqs: int
|
||||
_use_compiled: bool
|
||||
|
||||
|
||||
chroma_params = ChromaParams(
|
||||
in_channels=3,
|
||||
context_in_dim=4096,
|
||||
hidden_size=3072,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=24,
|
||||
depth=19,
|
||||
depth_single_blocks=38,
|
||||
axes_dim=[16, 56, 56],
|
||||
theta=10_000,
|
||||
qkv_bias=True,
|
||||
guidance_embed=True,
|
||||
approximator_in_dim=64,
|
||||
approximator_depth=5,
|
||||
approximator_hidden_size=5120,
|
||||
patch_size=16,
|
||||
nerf_hidden_size=64,
|
||||
nerf_mlp_ratio=4,
|
||||
nerf_depth=4,
|
||||
nerf_max_freqs=8,
|
||||
_use_compiled=False,
|
||||
)
|
||||
|
||||
|
||||
def modify_mask_to_attend_padding(mask, max_seq_length, num_extra_padding=8):
|
||||
"""
|
||||
Modifies attention mask to allow attention to a few extra padding tokens.
|
||||
|
||||
Args:
|
||||
mask: Original attention mask (1 for tokens to attend to, 0 for masked tokens)
|
||||
max_seq_length: Maximum sequence length of the model
|
||||
num_extra_padding: Number of padding tokens to unmask
|
||||
|
||||
Returns:
|
||||
Modified mask
|
||||
"""
|
||||
# Get the actual sequence length from the mask
|
||||
seq_length = mask.sum(dim=-1)
|
||||
batch_size = mask.shape[0]
|
||||
|
||||
modified_mask = mask.clone()
|
||||
|
||||
for i in range(batch_size):
|
||||
current_seq_len = int(seq_length[i].item())
|
||||
|
||||
# Only add extra padding tokens if there's room
|
||||
if current_seq_len < max_seq_length:
|
||||
# Calculate how many padding tokens we can unmask
|
||||
available_padding = max_seq_length - current_seq_len
|
||||
tokens_to_unmask = min(num_extra_padding, available_padding)
|
||||
|
||||
# Unmask the specified number of padding tokens right after the sequence
|
||||
modified_mask[i, current_seq_len : current_seq_len + tokens_to_unmask] = 1
|
||||
|
||||
return modified_mask
|
||||
|
||||
|
||||
class Chroma(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
def __init__(self, params: ChromaParams):
|
||||
super().__init__()
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = self.in_channels
|
||||
self.gradient_checkpointing = False
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
)
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(
|
||||
f"Got {params.axes_dim} but expected positional dim {pe_dim}"
|
||||
)
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(
|
||||
dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
|
||||
)
|
||||
# self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
||||
# patchify ops
|
||||
self.img_in_patch = nn.Conv2d(
|
||||
params.in_channels,
|
||||
params.hidden_size,
|
||||
kernel_size=params.patch_size,
|
||||
stride=params.patch_size,
|
||||
bias=True
|
||||
)
|
||||
nn.init.zeros_(self.img_in_patch.weight)
|
||||
nn.init.zeros_(self.img_in_patch.bias)
|
||||
# TODO: need proper mapping for this approximator output!
|
||||
# currently the mapping is hardcoded in distribute_modulations function
|
||||
self.distilled_guidance_layer = Approximator(
|
||||
params.approximator_in_dim,
|
||||
self.hidden_size,
|
||||
params.approximator_hidden_size,
|
||||
params.approximator_depth,
|
||||
)
|
||||
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
use_compiled=params._use_compiled,
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
use_compiled=params._use_compiled,
|
||||
)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
# self.final_layer = LastLayer(
|
||||
# self.hidden_size,
|
||||
# 1,
|
||||
# self.out_channels,
|
||||
# use_compiled=params._use_compiled,
|
||||
# )
|
||||
|
||||
# pixel channel concat with DCT
|
||||
self.nerf_image_embedder = NerfEmbedder(
|
||||
in_channels=params.in_channels,
|
||||
hidden_size_input=params.nerf_hidden_size,
|
||||
max_freqs=params.nerf_max_freqs
|
||||
)
|
||||
|
||||
self.nerf_blocks = nn.ModuleList([
|
||||
NerfGLUBlock(
|
||||
hidden_size_s=params.hidden_size,
|
||||
hidden_size_x=params.nerf_hidden_size,
|
||||
mlp_ratio=params.nerf_mlp_ratio,
|
||||
use_compiled=params._use_compiled
|
||||
) for _ in range(params.nerf_depth)
|
||||
])
|
||||
# self.nerf_final_layer = NerfFinalLayer(
|
||||
# params.nerf_hidden_size,
|
||||
# out_channels=params.in_channels,
|
||||
# use_compiled=params._use_compiled
|
||||
# )
|
||||
self.nerf_final_layer_conv = NerfFinalLayerConv(
|
||||
params.nerf_hidden_size,
|
||||
out_channels=params.in_channels,
|
||||
use_compiled=params._use_compiled
|
||||
)
|
||||
# TODO: move this hardcoded value to config
|
||||
# single layer has 3 modulation vectors
|
||||
# double layer has 6 modulation vectors for each expert
|
||||
# final layer has 2 modulation vectors
|
||||
self.mod_index_length = 3 * params.depth_single_blocks + 2 * 6 * params.depth + 2
|
||||
self.depth_single_blocks = params.depth_single_blocks
|
||||
self.depth_double_blocks = params.depth
|
||||
# self.mod_index = torch.tensor(list(range(self.mod_index_length)), device=0)
|
||||
self.register_buffer(
|
||||
"mod_index",
|
||||
torch.tensor(list(range(self.mod_index_length)), device="cpu"),
|
||||
persistent=False,
|
||||
)
|
||||
self.approximator_in_dim = params.approximator_in_dim
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
# Get the device of the module (assumes all parameters are on the same device)
|
||||
return next(self.parameters()).device
|
||||
|
||||
def enable_gradient_checkpointing(self, enable: bool = True):
|
||||
self.gradient_checkpointing = enable
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
txt_mask: Tensor,
|
||||
timesteps: Tensor,
|
||||
guidance: Tensor,
|
||||
attn_padding: int = 1,
|
||||
) -> Tensor:
|
||||
if img.ndim != 4:
|
||||
raise ValueError("Input img tensor must be in [B, C, H, W] format.")
|
||||
if txt.ndim != 3:
|
||||
raise ValueError("Input txt tensors must have 3 dimensions.")
|
||||
B, C, H, W = img.shape
|
||||
|
||||
# gemini gogogo idk how to unfold and pack the patch properly :P
|
||||
# Store the raw pixel values of each patch for the NeRF head later.
|
||||
# unfold creates patches: [B, C * P * P, NumPatches]
|
||||
nerf_pixels = nn.functional.unfold(img, kernel_size=self.params.patch_size, stride=self.params.patch_size)
|
||||
nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]
|
||||
|
||||
# partchify ops
|
||||
img = self.img_in_patch(img) # -> [B, Hidden, H/P, W/P]
|
||||
num_patches = img.shape[2] * img.shape[3]
|
||||
# flatten into a sequence for the transformer.
|
||||
img = img.flatten(2).transpose(1, 2) # -> [B, NumPatches, Hidden]
|
||||
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
# TODO:
|
||||
# need to fix grad accumulation issue here for now it's in no grad mode
|
||||
# besides, i don't want to wash out the PFP that's trained on this model weights anyway
|
||||
# the fan out operation here is deleting the backward graph
|
||||
# alternatively doing forward pass for every block manually is doable but slow
|
||||
# custom backward probably be better
|
||||
with torch.no_grad():
|
||||
distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim//4)
|
||||
# TODO: need to add toggle to omit this from schnell but that's not a priority
|
||||
distil_guidance = timestep_embedding(guidance, self.approximator_in_dim//4)
|
||||
# get all modulation index
|
||||
modulation_index = timestep_embedding(self.mod_index, self.approximator_in_dim//2)
|
||||
# we need to broadcast the modulation index here so each batch has all of the index
|
||||
modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1)
|
||||
# and we need to broadcast timestep and guidance along too
|
||||
timestep_guidance = (
|
||||
torch.cat([distill_timestep, distil_guidance], dim=1)
|
||||
.unsqueeze(1)
|
||||
.repeat(1, self.mod_index_length, 1)
|
||||
)
|
||||
# then and only then we could concatenate it together
|
||||
input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1)
|
||||
mod_vectors = self.distilled_guidance_layer(input_vec.requires_grad_(True))
|
||||
mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
# compute mask
|
||||
# assume max seq length from the batched input
|
||||
|
||||
max_len = txt.shape[1]
|
||||
|
||||
# mask
|
||||
with torch.no_grad():
|
||||
txt_mask_w_padding = modify_mask_to_attend_padding(
|
||||
txt_mask, max_len, attn_padding
|
||||
)
|
||||
txt_img_mask = torch.cat(
|
||||
[
|
||||
txt_mask_w_padding,
|
||||
torch.ones([img.shape[0], img.shape[1]], device=txt_mask.device),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
txt_img_mask = txt_img_mask.float().T @ txt_img_mask.float()
|
||||
txt_img_mask = (
|
||||
txt_img_mask[None, None, ...]
|
||||
.repeat(txt.shape[0], self.num_heads, 1, 1)
|
||||
.int()
|
||||
.bool()
|
||||
)
|
||||
# txt_mask_w_padding[txt_mask_w_padding==False] = True
|
||||
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
# the guidance replaced by FFN output
|
||||
img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"]
|
||||
txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"]
|
||||
double_mod = [img_mod, txt_mod]
|
||||
|
||||
# just in case in different GPU for simple pipeline parallel
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
img.requires_grad_(True)
|
||||
img, txt = ckpt.checkpoint(
|
||||
block, img, txt, pe, double_mod, txt_img_mask
|
||||
)
|
||||
else:
|
||||
img, txt = block(
|
||||
img=img, txt=txt, pe=pe, distill_vec=double_mod, mask=txt_img_mask
|
||||
)
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"]
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
img.requires_grad_(True)
|
||||
img = ckpt.checkpoint(block, img, pe, single_mod, txt_img_mask)
|
||||
else:
|
||||
img = block(img, pe=pe, distill_vec=single_mod, mask=txt_img_mask)
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
# final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"]
|
||||
# img = self.final_layer(
|
||||
# img, distill_vec=final_mod
|
||||
# ) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
# aliasing
|
||||
nerf_hidden = img
|
||||
# reshape for per-patch processing
|
||||
nerf_hidden = nerf_hidden.reshape(B * num_patches, self.params.hidden_size)
|
||||
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, self.params.patch_size**2).transpose(1, 2)
|
||||
|
||||
# get DCT-encoded pixel embeddings [pixel-dct]
|
||||
img_dct = self.nerf_image_embedder(nerf_pixels)
|
||||
|
||||
# pass through the dynamic MLP blocks (the NeRF)
|
||||
for i, block in enumerate(self.nerf_blocks):
|
||||
if self.training:
|
||||
img_dct = ckpt.checkpoint(block, img_dct, nerf_hidden)
|
||||
else:
|
||||
img_dct = block(img_dct, nerf_hidden)
|
||||
|
||||
# final projection to get the output pixel values
|
||||
# img_dct = self.nerf_final_layer(img_dct) # -> [B*NumPatches, P*P, C]
|
||||
img_dct = self.nerf_final_layer_conv.norm(img_dct)
|
||||
|
||||
# gemini gogogo idk how to fold this properly :P
|
||||
# Reassemble the patches into the final image.
|
||||
img_dct = img_dct.transpose(1, 2) # -> [B*NumPatches, C, P*P]
|
||||
# Reshape to combine with batch dimension for fold
|
||||
img_dct = img_dct.reshape(B, num_patches, -1) # -> [B, NumPatches, C*P*P]
|
||||
img_dct = img_dct.transpose(1, 2) # -> [B, C*P*P, NumPatches]
|
||||
img_dct = nn.functional.fold(
|
||||
img_dct,
|
||||
output_size=(H, W),
|
||||
kernel_size=self.params.patch_size,
|
||||
stride=self.params.patch_size
|
||||
) # [B, Hidden, H, W]
|
||||
img_dct = self.nerf_final_layer_conv.conv(img_dct)
|
||||
|
||||
return img_dct
|
||||
134
toolkit/models/FakeVAE.py
Normal file
134
toolkit/models/FakeVAE.py
Normal file
@@ -0,0 +1,134 @@
|
||||
from diffusers import AutoencoderKL
|
||||
from typing import Optional, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKLOutput
|
||||
from diffusers.models.autoencoders.vae import DecoderOutput
|
||||
|
||||
|
||||
class Config:
|
||||
in_channels = 3
|
||||
out_channels = 3
|
||||
down_block_types = ("1",)
|
||||
up_block_types = ("1",)
|
||||
block_out_channels = (1,)
|
||||
latent_channels = 3 # usually 4
|
||||
norm_num_groups = 1
|
||||
sample_size = 512
|
||||
scaling_factor = 1.0
|
||||
# scaling_factor = 1.8
|
||||
shift_factor = 0
|
||||
|
||||
def __getitem__(cls, x):
|
||||
return getattr(cls, x)
|
||||
|
||||
|
||||
class FakeVAE(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._dtype = torch.float32
|
||||
self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.config = Config()
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._dtype
|
||||
|
||||
@dtype.setter
|
||||
def dtype(self, value):
|
||||
self._dtype = value
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self._device
|
||||
|
||||
@device.setter
|
||||
def device(self, value):
|
||||
self._device = value
|
||||
|
||||
# mimic to from torch
|
||||
def to(self, *args, **kwargs):
|
||||
# pull out dtype and device if they exist
|
||||
if "dtype" in kwargs:
|
||||
self._dtype = kwargs["dtype"]
|
||||
if "device" in kwargs:
|
||||
self._device = kwargs["device"]
|
||||
return super().to(*args, **kwargs)
|
||||
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
pass
|
||||
|
||||
# @apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.FloatTensor, return_dict: bool = True
|
||||
) -> AutoencoderKLOutput:
|
||||
h = x
|
||||
|
||||
# moments = self.quant_conv(h)
|
||||
# posterior = DiagonalGaussianDistribution(moments)
|
||||
|
||||
if not return_dict:
|
||||
return (h,)
|
||||
|
||||
class FakeDist:
|
||||
def __init__(self, x):
|
||||
self._sample = x
|
||||
|
||||
def sample(self):
|
||||
return self._sample
|
||||
|
||||
return AutoencoderKLOutput(latent_dist=FakeDist(h))
|
||||
|
||||
def _decode(
|
||||
self, z: torch.FloatTensor, return_dict: bool = True
|
||||
) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
dec = z
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
# @apply_forward_hook
|
||||
def decode(
|
||||
self, z: torch.FloatTensor, return_dict: bool = True
|
||||
) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
decoded = self._decode(z).sample
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
pass
|
||||
|
||||
def enable_tiling(self, use_tiling: bool = True):
|
||||
pass
|
||||
|
||||
def disable_tiling(self):
|
||||
pass
|
||||
|
||||
def enable_slicing(self):
|
||||
pass
|
||||
|
||||
def disable_slicing(self):
|
||||
pass
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, value: bool = True):
|
||||
pass
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
dec = sample
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
Reference in New Issue
Block a user