import os
import re
import random
from typing import TYPE_CHECKING

import torch
import torch.nn.functional as F
from PIL import Image
from einops import rearrange, repeat
from diffusers import AutoencoderKL
from optimum.quanto import freeze, QTensor
from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer
from safetensors.torch import load_file, save_file
import huggingface_hub

from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.models.base_model import BaseModel
from toolkit.basic import flush
# from toolkit.pixel_shuffle_encoder import AutoencoderPixelMixer
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.accelerator import unwrap_model
from toolkit.util.quantize import quantize, get_qtype
from toolkit.metadata import get_meta_for_safetensors

from .pipeline import ChromaPipeline
from .src.model import Chroma, chroma_params

if TYPE_CHECKING:
    from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO

scheduler_config = {
    "base_image_seq_len": 256,
    "base_shift": 0.5,
    "max_image_seq_len": 4096,
    "max_shift": 1.15,
    "num_train_timesteps": 1000,
    "shift": 3.0,
    "use_dynamic_shifting": True
}
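
# note on use_dynamic_shifting: with the values above, the flow-matching
# timestep shift is interpolated from the image token count. A minimal sketch
# of the usual calculation (following the diffusers Flux convention; shown for
# illustration only, not called anywhere in this file):
#
#   def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096,
#                       base_shift=0.5, max_shift=1.15):
#       m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
#       b = base_shift - m * base_seq_len
#       return image_seq_len * m + b  # mu passed to the scheduler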


class FakeConfig:
    # for diffusers compatibility
    def __init__(self):
        self.attention_head_dim = 128
        self.guidance_embeds = True
        self.in_channels = 64
        self.joint_attention_dim = 4096
        self.num_attention_heads = 24
        self.num_layers = 19
        self.num_single_layers = 38
        self.patch_size = 1


class FakeCLIP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dtype = torch.bfloat16
        self.device = 'cuda'
        self.text_model = None
        self.tokenizer = None
        self.model_max_length = 77

    def forward(self, *args, **kwargs):
        return torch.zeros(1, 1, 1).to(self.device)


class ChromaModel(BaseModel):
    arch = "chroma"

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype='bf16',
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs
    ):
        super().__init__(
            device,
            model_config,
            dtype,
            custom_pipeline,
            noise_scheduler,
            **kwargs
        )
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ['Chroma']

    # static method to get the noise scheduler
    @staticmethod
    def get_train_scheduler():
        return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)

    def get_bucket_divisibility(self):
        # return the bucket divisibility for the model
        return 32
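        # (the VAE downsamples 8x and the transformer packs 2x2 latent
        # patches, so 16 is the hard floor; 32 presumably leaves headroom
        # for bucket math)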

    def load_model(self):
        dtype = self.torch_dtype

        # will be updated if we detect an existing checkpoint in training folder
        model_path = self.model_config.name_or_path

        if model_path == "lodestones/Chroma":
            print("Looking for latest Chroma checkpoint")
            # scan the repo for the highest-numbered checkpoint instead of
            # probing outward from a hardcoded guess
            files_list = huggingface_hub.list_repo_files(model_path)
            print(files_list)
            version_pattern = re.compile(r"chroma-unlocked-v(\d+)\.safetensors")
            versions = [
                int(m.group(1))
                for f in files_list
                if (m := version_pattern.fullmatch(f))
            ]
            if not versions:
                raise ValueError(f"No Chroma checkpoints found in {model_path}")
            latest_version = max(versions)
            print(f"Using latest Chroma version: v{latest_version}")

            # make sure we have it
            model_path = huggingface_hub.hf_hub_download(
                repo_id=model_path,
                filename=f"chroma-unlocked-v{latest_version}.safetensors",
            )
        elif model_path.startswith("lodestones/Chroma/v"):
            # get the version number
            version = model_path.split("/")[-1].split("v")[-1]
            print(f"Using Chroma version: v{version}")
            # make sure we have it
            model_path = huggingface_hub.hf_hub_download(
                repo_id='lodestones/Chroma',
                filename=f"chroma-unlocked-v{version}.safetensors",
            )
        else:
            # check if the model path is a local file
            if os.path.exists(model_path):
                print(f"Using local model: {model_path}")
            else:
                raise ValueError(f"Model path {model_path} does not exist")
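
        # the three accepted forms of model_config.name_or_path, for reference
        # (the version number below is illustrative):
        #   "lodestones/Chroma"           -> newest checkpoint in the HF repo
        #   "lodestones/Chroma/v29"       -> a pinned checkpoint version
        #   "/path/to/model.safetensors"  -> a local checkpoint file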

        # extras_path = 'black-forest-labs/FLUX.1-schnell'
        # schnell model is gated now, use flex instead
        extras_path = 'ostris/Flex.1-alpha'

        self.print_and_status_update("Loading transformer")

        chroma_state_dict = load_file(model_path, 'cpu')

        # determine the number of double and single blocks from the state
        # dict so we can instantiate a transformer of the right depth
        double_blocks = 0
        single_blocks = 0
        for key in chroma_state_dict.keys():
            if "double_blocks" in key:
                block_num = int(key.split(".")[1]) + 1
                if block_num > double_blocks:
                    double_blocks = block_num
            elif "single_blocks" in key:
                block_num = int(key.split(".")[1]) + 1
                if block_num > single_blocks:
                    single_blocks = block_num
        print(f"Double Blocks: {double_blocks}")
        print(f"Single Blocks: {single_blocks}")

        chroma_params.depth = double_blocks
        chroma_params.depth_single_blocks = single_blocks
        transformer = Chroma(chroma_params)

        # add dtype, not sure why it doesn't have it
        transformer.dtype = dtype
        # load the state dict into the model
        transformer.load_state_dict(chroma_state_dict)

        transformer.to(self.quantize_device, dtype=dtype)

        # fake diffusers-style config so downstream code can read layer counts
        transformer.config = FakeConfig()
        transformer.config.num_layers = double_blocks
        transformer.config.num_single_layers = single_blocks

        if self.model_config.quantize:
            # patch the state dict method
            patch_dequantization_on_save(transformer)
            quantization_type = get_qtype(self.model_config.qtype)
            self.print_and_status_update("Quantizing transformer")
            quantize(transformer, weights=quantization_type,
                     **self.model_config.quantize_kwargs)
            freeze(transformer)
            transformer.to(self.device_torch)
        else:
            transformer.to(self.device_torch, dtype=dtype)

        flush()
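        # roughly: quantize() above swaps eligible layers for quantized
        # variants and freeze() bakes their weights down to the quantized
        # representation, releasing the full-precision copies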

        self.print_and_status_update("Loading T5")
        tokenizer_2 = T5TokenizerFast.from_pretrained(
            extras_path, subfolder="tokenizer_2", torch_dtype=dtype
        )
        text_encoder_2 = T5EncoderModel.from_pretrained(
            extras_path, subfolder="text_encoder_2", torch_dtype=dtype
        )
        text_encoder_2.to(self.device_torch, dtype=dtype)
        flush()

        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing T5")
            quantize(text_encoder_2, weights=get_qtype(
                self.model_config.qtype))
            freeze(text_encoder_2)
            flush()

        # self.print_and_status_update("Loading CLIP")
        # Chroma does not use a CLIP text encoder, so stand-ins keep the
        # pipeline signature intact
        text_encoder = FakeCLIP()
        tokenizer = FakeCLIP()
        text_encoder.to(self.device_torch, dtype=dtype)

        self.noise_scheduler = ChromaModel.get_train_scheduler()

        self.print_and_status_update("Loading VAE")
        vae = AutoencoderKL.from_pretrained(
            extras_path,
            subfolder="vae",
            torch_dtype=dtype
        )
        vae = vae.to(self.device_torch, dtype=dtype)

        self.print_and_status_update("Making pipe")

        pipe: ChromaPipeline = ChromaPipeline(
            scheduler=self.noise_scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_encoder_2=None,
            tokenizer_2=tokenizer_2,
            vae=vae,
            transformer=None,
        )
        # for quantization, it works best to attach these after making the pipe
        pipe.text_encoder_2 = text_encoder_2
        pipe.transformer = transformer

        self.print_and_status_update("Preparing Model")

        text_encoder = [pipe.text_encoder, pipe.text_encoder_2]
        tokenizer = [pipe.tokenizer, pipe.tokenizer_2]

        pipe.transformer = pipe.transformer.to(self.device_torch)

        flush()
        # just to make sure everything is on the right device and dtype
        text_encoder[0].to(self.device_torch)
        text_encoder[0].requires_grad_(False)
        text_encoder[0].eval()
        text_encoder[1].to(self.device_torch)
        text_encoder[1].requires_grad_(False)
        text_encoder[1].eval()
        pipe.transformer = pipe.transformer.to(self.device_torch)
        flush()

        # save it to the model class
        self.vae = vae
        self.text_encoder = text_encoder  # list of text encoders
        self.tokenizer = tokenizer  # list of tokenizers
        self.model = pipe.transformer
        self.pipeline = pipe
        self.print_and_status_update("Model Loaded")
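
    # a minimal sketch of typical construction, assuming the surrounding
    # trainer supplies these values (illustrative, not an official entry point;
    # the exact ModelConfig signature may differ):
    #
    #   config = ModelConfig(name_or_path="lodestones/Chroma",
    #                        quantize=True, qtype="qfloat8")
    #   model = ChromaModel(device="cuda:0", model_config=config)
    #   model.load_model()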

    def get_generation_pipeline(self):
        scheduler = ChromaModel.get_train_scheduler()
        pipeline = ChromaPipeline(
            scheduler=scheduler,
            text_encoder=unwrap_model(self.text_encoder[0]),
            tokenizer=self.tokenizer[0],
            text_encoder_2=unwrap_model(self.text_encoder[1]),
            tokenizer_2=self.tokenizer[1],
            vae=unwrap_model(self.vae),
            transformer=unwrap_model(self.transformer)
        )

        # pipeline = pipeline.to(self.device_torch)

        return pipeline

    def generate_single_image(
        self,
        pipeline: ChromaPipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        extra['negative_prompt_embeds'] = unconditional_embeds.text_embeds
        extra['negative_prompt_attn_mask'] = unconditional_embeds.attention_mask

        img = pipeline(
            prompt_embeds=conditional_embeds.text_embeds,
            prompt_attn_mask=conditional_embeds.attention_mask,
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            generator=generator,
            **extra
        ).images[0]
        return img

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        **kwargs
    ):
        with torch.no_grad():
            bs, c, h, w = latent_model_input.shape
            # pack 2x2 latent patches into tokens:
            # (b, c, h, w) -> (b, h/2 * w/2, c * 4)
            latent_model_input_packed = rearrange(
                latent_model_input,
                "b c (h ph) (w pw) -> b (h w) (c ph pw)",
                ph=2,
                pw=2
            )
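            # shape check, assuming a 1024x1024 image and the usual 8x VAE
            # with 16 latent channels: latents are (b, 16, 128, 128), which
            # packs to (b, 4096, 64) tokens, matching in_channels=64 above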

            # per-token position ids (spare dim, row, col) for the image tokens
            img_ids = torch.zeros(h // 2, w // 2, 3)
            img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
            img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
            img_ids = repeat(img_ids, "h w c -> b (h w) c",
                             b=bs).to(self.device_torch)

            # text tokens all share position id 0
            txt_ids = torch.zeros(
                bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)

            # the guidance embed input is zeroed out for training
            guidance = torch.full([1], 0, device=self.device_torch, dtype=torch.float32)
            guidance = guidance.expand(latent_model_input_packed.shape[0])

        cast_dtype = self.unet.dtype

        noise_pred = self.unet(
            img=latent_model_input_packed.to(
                self.device_torch, cast_dtype
            ),
            img_ids=img_ids,
            txt=text_embeddings.text_embeds.to(
                self.device_torch, cast_dtype
            ),
            txt_ids=txt_ids,
            txt_mask=text_embeddings.attention_mask.to(
                self.device_torch, cast_dtype
            ),
            timesteps=timestep / 1000,  # scheduler uses 0-1000, model expects 0-1
            guidance=guidance
        )

        if isinstance(noise_pred, QTensor):
            noise_pred = noise_pred.dequantize()

        # unpack tokens back to latent layout:
        # (b, h/2 * w/2, c * 4) -> (b, c, h, w)
        noise_pred = rearrange(
            noise_pred,
            "b (h w) (c ph pw) -> b c (h ph) (w pw)",
            h=latent_model_input.shape[2] // 2,
            w=latent_model_input.shape[3] // 2,
            ph=2,
            pw=2,
            c=self.vae.config.latent_channels
        )

        return noise_pred

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        if isinstance(prompt, str):
            prompts = [prompt]
        else:
            prompts = prompt
        # T5 is the only real text encoder; make sure it is on device
        if self.pipeline.text_encoder_2.device != self.device_torch:
            self.pipeline.text_encoder_2.to(self.device_torch)

        max_length = 512

        device = self.text_encoder[1].device
        dtype = self.text_encoder[1].dtype

        # T5
        text_inputs = self.tokenizer[1](
            prompts,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        prompt_embeds = self.text_encoder[1](text_input_ids.to(device), output_hidden_states=False)[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        prompt_attention_mask = text_inputs["attention_mask"]

        pe = PromptEmbeds(prompt_embeds)
        pe.attention_mask = prompt_attention_mask
        return pe
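
    # resulting shapes, assuming a T5-XXL encoder (d_model=4096):
    #   pe.text_embeds    -> (batch, 512, 4096), matching joint_attention_dim
    #   pe.attention_mask -> (batch, 512)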

    def get_model_has_grad(self):
        # return from a weight if it has grad
        return self.model.final_layer.linear.weight.requires_grad

    def get_te_has_grad(self):
        # return from a weight if it has grad
        return self.text_encoder[1].encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad

    def save_model(self, output_path, meta, save_dtype):
        # only save the transformer
        transformer: Chroma = unwrap_model(self.model)
        state_dict = transformer.state_dict()
        save_dict = {}
        for k, v in state_dict.items():
            if isinstance(v, QTensor):
                v = v.dequantize()
            save_dict[k] = v.clone().to('cpu', dtype=save_dtype)

        meta = get_meta_for_safetensors(meta, name='chroma')
        save_file(save_dict, output_path, metadata=meta)

    def get_loss_target(self, *args, **kwargs):
        noise = kwargs.get('noise')
        batch = kwargs.get('batch')
        return (noise - batch.latents).detach()
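
    # rectified-flow target: with x_t = (1 - t) * x_0 + t * noise, the
    # velocity the model should predict is d x_t / d t = noise - x_0,
    # which is exactly (noise - batch.latents) above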

    def convert_lora_weights_before_save(self, state_dict):
        # keys currently start with "transformer." but need to start with
        # "diffusion_model." for ComfyUI
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("transformer.", "diffusion_model.")
            new_sd[new_key] = value
        return new_sd

    def convert_lora_weights_before_load(self, state_dict):
        # saved as "diffusion_model." but needs to be "transformer." for ai-toolkit
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("diffusion_model.", "transformer.")
            new_sd[new_key] = value
        return new_sd
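
    # e.g. (hypothetical LoRA key, for illustration):
    #   save: "transformer.double_blocks.0.img_attn.qkv.lora_A.weight"
    #      -> "diffusion_model.double_blocks.0.img_attn.qkv.lora_A.weight"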

    def get_base_model_version(self):
        return "chroma"