Partial implementation for training auraflow.

Jaret Burkett
2024-07-12 12:11:38 -06:00
parent c062b7716c
commit e4558dff4b
9 changed files with 386 additions and 19 deletions

View File

@@ -883,6 +883,26 @@ class SDTrainer(BaseSDTrainProcess):
    def end_of_training_loop(self):
        pass

    def predict_noise(
            self,
            noisy_latents: torch.Tensor,
            timesteps: Union[int, torch.Tensor] = 1,
            conditional_embeds: Union[PromptEmbeds, None] = None,
            unconditional_embeds: Union[PromptEmbeds, None] = None,
            **kwargs,
    ):
        dtype = get_torch_dtype(self.train_config.dtype)
        return self.sd.predict_noise(
            latents=noisy_latents.to(self.device_torch, dtype=dtype),
            conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
            unconditional_embeddings=unconditional_embeds,
            timestep=timesteps,
            guidance_scale=self.train_config.cfg_scale,
            detach_unconditional=False,
            rescale_cfg=self.train_config.cfg_rescale,
            **kwargs
        )

    def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
        self.timer.start('preprocess_batch')
        batch = self.preprocess_batch(batch)
@@ -1453,14 +1473,11 @@ class SDTrainer(BaseSDTrainProcess):
with self.timer('predict_unet'):
if unconditional_embeds is not None:
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
noise_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
unconditional_embeddings=unconditional_embeds,
timestep=timesteps,
guidance_scale=self.train_config.cfg_scale,
detach_unconditional=False,
rescale_cfg=self.train_config.cfg_rescale,
noise_pred = self.predict_noise(
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
timesteps=timesteps,
conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
unconditional_embeds=unconditional_embeds,
**pred_kwargs
)
self.after_unet_predict()
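Note: extracting predict_noise into its own method turns noise prediction into an overridable hook rather than an inline call. A minimal sketch of how a model-specific trainer could build on it (the subclass and its timestep handling are hypothetical, not part of this commit):

class AuraFlowSDTrainer(SDTrainer):  # hypothetical subclass, for illustration only
    def predict_noise(self, noisy_latents, timesteps=1,
                      conditional_embeds=None, unconditional_embeds=None, **kwargs):
        # e.g. adjust timesteps for a flow-matching model before delegating
        # to the shared implementation added above
        return super().predict_noise(
            noisy_latents,
            timesteps=timesteps,
            conditional_embeds=conditional_embeds,
            unconditional_embeds=unconditional_embeds,
            **kwargs,
        )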

View File

@@ -1287,6 +1287,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
is_v2=self.model_config.is_v2,
is_v3=self.model_config.is_v3,
is_pixart=self.model_config.is_pixart,
is_auraflow=self.model_config.is_auraflow,
is_ssd=self.model_config.is_ssd,
is_vega=self.model_config.is_vega,
dropout=self.network_config.dropout,

View File

@@ -53,6 +53,50 @@ resolutions_1024: List[BucketResolution] = [
{"width": 512, "height": 2048},
]
# Even dimensions so they can be split into patches more easily
resolutions_dit_1024: List[BucketResolution] = [
# Base resolution
{"width": 1024, "height": 1024},
# widescreen
{"width": 2048, "height": 512},
{"width": 1792, "height": 576},
{"width": 1728, "height": 576},
{"width": 1664, "height": 576},
{"width": 1600, "height": 640},
{"width": 1536, "height": 640},
{"width": 1472, "height": 704},
{"width": 1408, "height": 704},
{"width": 1344, "height": 704},
{"width": 1344, "height": 768},
{"width": 1280, "height": 768},
{"width": 1216, "height": 832},
{"width": 1152, "height": 832},
{"width": 1152, "height": 896},
{"width": 1088, "height": 896},
{"width": 1088, "height": 960},
{"width": 1024, "height": 960},
# portrait
{"width": 960, "height": 1024},
{"width": 960, "height": 1088},
{"width": 896, "height": 1088},
{"width": 896, "height": 1152}, # 2:3
{"width": 832, "height": 1152},
{"width": 832, "height": 1216},
{"width": 768, "height": 1280},
{"width": 768, "height": 1344},
{"width": 704, "height": 1408},
{"width": 704, "height": 1472},
{"width": 640, "height": 1536},
{"width": 640, "height": 1600},
{"width": 576, "height": 1664},
{"width": 576, "height": 1728},
{"width": 576, "height": 1792},
{"width": 512, "height": 1856},
{"width": 512, "height": 1920},
{"width": 512, "height": 1984},
{"width": 512, "height": 2048},
]
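The DiT buckets above keep both sides divisible into whole patches; a quick standalone sanity check is shown below (the 8x VAE downscale and the latent patch size of 2 are assumptions about the target DiT, not values taken from this commit):

# Standalone sanity check of a few DiT buckets (values copied from the list above).
buckets = [
    {"width": 1024, "height": 1024},
    {"width": 1344, "height": 768},
    {"width": 576, "height": 1792},
]
VAE_SCALE = 8   # assumed
PATCH_SIZE = 2  # assumed
for b in buckets:
    lw, lh = b["width"] // VAE_SCALE, b["height"] // VAE_SCALE
    assert lw % PATCH_SIZE == 0 and lh % PATCH_SIZE == 0, b
    print(b, "->", (lw // PATCH_SIZE) * (lh // PATCH_SIZE), "patches")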
def get_bucket_sizes(resolution: int = 512, divisibility: int = 8) -> List[BucketResolution]:
# determine the scale factor from 1024 to the target resolution

View File

@@ -350,6 +350,7 @@ class ModelConfig:
self.is_xl: bool = kwargs.get('is_xl', False)
self.is_pixart: bool = kwargs.get('is_pixart', False)
self.is_pixart_sigma: bool = kwargs.get('is_pixart_sigma', False)
self.is_auraflow: bool = kwargs.get('is_auraflow', False)
self.is_v3: bool = kwargs.get('is_v3', False)
if self.is_pixart_sigma:
self.is_pixart = True
@@ -381,7 +382,7 @@ class ModelConfig:
self.is_xl = True
# for text encoder quant. Only works with pixart and auraflow currently
self.text_encoder_bits = kwargs.get('text_encoder_bits', 8) # 16, 8, 4
self.text_encoder_bits = kwargs.get('text_encoder_bits', 16) # 16, 8, 4
self.unet_path = kwargs.get("unet_path", None)
self.unet_sample_size = kwargs.get("unet_sample_size", None)
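For orientation, a sketch of how the new flag and the quantization setting would be passed in (the name_or_path key and checkpoint id are illustrative assumptions; only is_auraflow and text_encoder_bits are touched by this commit):

# Hypothetical construction, for illustration only
model_config = ModelConfig(
    name_or_path="fal/AuraFlow",  # assumed checkpoint id, not from this commit
    is_auraflow=True,             # new flag added above
    text_encoder_bits=8,          # quantize the UMT5 text encoder (default is now 16)
)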

View File

@@ -1355,6 +1355,10 @@ class LatentCachingMixin:
file_item.latent_space_version = 'sdxl'
elif self.sd.is_v3:
file_item.latent_space_version = 'sd3'
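# AuraFlow and PixArt-Sigma use the SDXL VAE, so their cached latents are tagged with the sdxl latent space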
elif self.sd.is_auraflow:
file_item.latent_space_version = 'sdxl'
elif self.sd.model_config.is_pixart_sigma:
file_item.latent_space_version = 'sdxl'
else:
file_item.latent_space_version = 'sd1'
file_item.is_caching_to_disk = to_disk

View File

@@ -7,7 +7,7 @@ import re
import sys
from typing import List, Optional, Dict, Type, Union
import torch
from diffusers import UNet2DConditionModel, PixArtTransformer2DModel
from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel
from transformers import CLIPTextModel
from .config_modules import NetworkConfig
@@ -158,6 +158,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
is_v2=False,
is_v3=False,
is_pixart: bool = False,
is_auraflow: bool = False,
use_bias: bool = False,
is_lorm: bool = False,
ignore_if_contains = None,
@@ -212,6 +213,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
self.is_v2 = is_v2
self.is_v3 = is_v3
self.is_pixart = is_pixart
self.is_auraflow = is_auraflow
self.network_type = network_type
if self.network_type.lower() == "dora":
self.module_class = DoRAModule
@@ -246,7 +248,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
target_replace_modules: List[torch.nn.Module],
) -> List[LoRAModule]:
unet_prefix = self.LORA_PREFIX_UNET
if is_pixart or is_v3:
if is_pixart or is_v3 or is_auraflow:
unet_prefix = f"lora_transformer"
prefix = (
@@ -371,6 +373,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
if is_pixart:
target_modules = ["PixArtTransformer2DModel"]
if is_auraflow:
target_modules = ["AuraFlowTransformer2DModel"]
if train_unet:
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
else:
@@ -408,6 +413,14 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
transformer.pos_embed = self.transformer_pos_embed
transformer.proj_out = self.transformer_proj_out
elif self.is_auraflow:
transformer: AuraFlowTransformer2DModel = unet
self.transformer_pos_embed = copy.deepcopy(transformer.pos_embed)
self.transformer_proj_out = copy.deepcopy(transformer.proj_out)
transformer.pos_embed = self.transformer_pos_embed
transformer.proj_out = self.transformer_proj_out
else:
unet: UNet2DConditionModel = unet
unet_conv_in: torch.nn.Conv2d = unet.conv_in
@@ -424,7 +437,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr)
if self.full_train_in_out:
if self.is_pixart:
if self.is_pixart or self.is_auraflow:
all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())})
all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())})
else:
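Since DiT models now share the lora_transformer prefix, saved AuraFlow LoRA keys follow the same naming scheme as PixArt/SD3. A standalone sketch of that mapping (the module path is a made-up example, and the helper below only mirrors the naming convention; it is not a toolkit function):

def lora_key(module_path: str, is_transformer: bool) -> str:
    # mirrors the prefix selection above: DiT models use "lora_transformer",
    # UNet models keep "lora_unet"
    prefix = "lora_transformer" if is_transformer else "lora_unet"
    return f"{prefix}_{module_path.replace('.', '_')}"

print(lora_key("single_transformer_blocks.0.attn.to_q", is_transformer=True))
# -> lora_transformer_single_transformer_blocks_0_attn_to_q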

toolkit/models/auraflow.py (new file, 127 lines)
View File

@@ -0,0 +1,127 @@
import math
from functools import partial

from torch import nn
import torch


class AuraFlowPatchEmbed(nn.Module):
    def __init__(
            self,
            height=224,
            width=224,
            patch_size=16,
            in_channels=3,
            embed_dim=768,
            pos_embed_max_size=None,
    ):
        super().__init__()

        self.num_patches = (height // patch_size) * (width // patch_size)
        self.pos_embed_max_size = pos_embed_max_size

        self.proj = nn.Linear(patch_size * patch_size * in_channels, embed_dim)
        self.pos_embed = nn.Parameter(torch.randn(1, pos_embed_max_size, embed_dim) * 0.1)

        self.patch_size = patch_size
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size

    def forward(self, latent):
        batch_size, num_channels, height, width = latent.size()

        latent = latent.view(
            batch_size,
            num_channels,
            height // self.patch_size,
            self.patch_size,
            width // self.patch_size,
            self.patch_size,
        )
        latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        latent = self.proj(latent)
        try:
            return latent + self.pos_embed
        except RuntimeError:
            raise RuntimeError(
                f"Positional embeddings are too small for the number of patches. "
                f"Please increase `pos_embed_max_size` to at least {self.num_patches}."
            )
    # comfy reference implementation (kept commented out for comparison)
    # def apply_pos_embeds(self, x, h, w):
    #     h = (h + 1) // self.patch_size
    #     w = (w + 1) // self.patch_size
    #     max_dim = max(h, w)
    #
    #     cur_dim = self.h_max
    #     pos_encoding = self.positional_encoding.reshape(1, cur_dim, cur_dim, -1).to(device=x.device, dtype=x.dtype)
    #
    #     if max_dim > cur_dim:
    #         pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
    #         cur_dim = max_dim
    #
    #     from_h = (cur_dim - h) // 2
    #     from_w = (cur_dim - w) // 2
    #     pos_encoding = pos_encoding[:, from_h:from_h + h, from_w:from_w + w]
    #     return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])

    # def patchify(self, x):
    #     B, C, H, W = x.size()
    #     pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
    #     pad_w = (self.patch_size - W % self.patch_size) % self.patch_size
    #
    #     x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
    #     x = x.view(
    #         B,
    #         C,
    #         (H + 1) // self.patch_size,
    #         self.patch_size,
    #         (W + 1) // self.patch_size,
    #         self.patch_size,
    #     )
    #     x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
    #     return x
def patch_auraflow_pos_embed(pos_embed):
    # Hijack the patch-embed module's forward with a custom one that pads or crops the
    # latent so it matches the fixed positional embedding. `self` below is the pos_embed
    # module being patched, not the full model.
    def new_forward(self, latent):
        batch_size, num_channels, height, width = latent.size()

        # add padding to the latent to make it match pos_embed
        # todo: confirm the 16 here; it is presumably patch_size**2 * num_channels (2 * 2 * 4),
        # which would make latent_size equal to the number of patches
        latent_size = height * width * num_channels / 16
        pos_embed_size = self.pos_embed.shape[1]
        if latent_size < pos_embed_size:
            total_padding = int(pos_embed_size - math.floor(latent_size))
            total_padding = total_padding // 16
            pad_height = total_padding // 2
            pad_width = total_padding - pad_height
            # mirror padding on the bottom/right side
            padding = (0, pad_width, 0, pad_height)
            latent = torch.nn.functional.pad(latent, padding, mode='reflect')
        elif latent_size > pos_embed_size:
            # crop rows off the bottom; cast to int so the slice index is valid
            amount_to_remove = int(latent_size - pos_embed_size)
            latent = latent[:, :, :-amount_to_remove]

        batch_size, num_channels, height, width = latent.size()

        latent = latent.view(
            batch_size,
            num_channels,
            height // self.patch_size,
            self.patch_size,
            width // self.patch_size,
            self.patch_size,
        )
        latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        latent = self.proj(latent)
        try:
            return latent + self.pos_embed
        except RuntimeError:
            raise RuntimeError(
                f"Positional embeddings are too small for the number of patches. "
                f"Please increase `pos_embed_max_size` to at least {self.num_patches}."
            )

    pos_embed.forward = partial(new_forward, pos_embed)
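For reference, a standalone sketch of the patchify reshape used in both forward methods above (the 4 latent channels, patch size of 2, and 128x128 latent are illustrative values for a 1024x1024 image):

import torch

latent = torch.randn(1, 4, 128, 128)  # (batch, channels, height, width)
p = 2
b, c, h, w = latent.shape
# split H and W into patches, then flatten each patch and the patch grid
patches = latent.view(b, c, h // p, p, w // p, p)
patches = patches.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
print(patches.shape)  # torch.Size([1, 4096, 16]) -> (batch, num_patches, p * p * c)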

View File

@@ -27,6 +27,7 @@ from library.model_util import convert_unet_state_dict_to_sd, convert_text_encod
from toolkit import train_tools
from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.metadata import get_meta_for_safetensors
from toolkit.models.auraflow import patch_auraflow_pos_embed
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
from toolkit.reference_adapter import ReferenceAdapter
@@ -40,13 +41,14 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline
StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \
StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel
import diffusers
from diffusers import \
AutoencoderKL, \
UNet2DConditionModel
from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler, PixArtSigmaPipeline
from transformers import T5EncoderModel, BitsAndBytesConfig
from transformers import T5EncoderModel, BitsAndBytesConfig, UMT5EncoderModel
from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
from toolkit.util.inverse_cfg import inverse_classifier_guidance
@@ -149,6 +151,7 @@ class StableDiffusion:
self.is_v3 = model_config.is_v3
self.is_vega = model_config.is_vega
self.is_pixart = model_config.is_pixart
self.is_auraflow = model_config.is_auraflow
self.use_text_encoder_1 = model_config.use_text_encoder_1
self.use_text_encoder_2 = model_config.use_text_encoder_2
@@ -371,6 +374,68 @@ class StableDiffusion:
            text_encoder.eval()
            pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
            tokenizer = pipe.tokenizer
        elif self.model_config.is_auraflow:
            te_kwargs = {}
            # handle quantization of the text encoder
            te_is_quantized = False
            if self.model_config.text_encoder_bits == 8:
                te_kwargs['load_in_8bit'] = True
                te_kwargs['device_map'] = "auto"
                te_is_quantized = True
            elif self.model_config.text_encoder_bits == 4:
                te_kwargs['load_in_4bit'] = True
                te_kwargs['device_map'] = "auto"
                te_is_quantized = True

            main_model_path = model_path

            # load the UMT5 text encoder, optionally quantized to 8 or 4 bit
            text_encoder = UMT5EncoderModel.from_pretrained(
                main_model_path,
                subfolder="text_encoder",
                torch_dtype=self.torch_dtype,
                **te_kwargs
            )

            # load the transformer
            subfolder = "transformer"
            # if the local path points directly at the transformer weights, there is no subfolder
            if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)):
                subfolder = None

            if te_is_quantized:
                # replace the .to() method with a no-op; moving a quantized model
                # raises an error rather than a warning
                text_encoder.to = lambda *args, **kwargs: None

            # load only the transformer from the checkpoint (or from unet_path if one is set)
            transformer = AuraFlowTransformer2DModel.from_pretrained(
                model_path if self.model_config.unet_path is None else self.model_config.unet_path,
                torch_dtype=self.torch_dtype,
                subfolder=subfolder
            )

            pipe: AuraFlowPipeline = AuraFlowPipeline.from_pretrained(
                main_model_path,
                transformer=transformer,
                text_encoder=text_encoder,
                dtype=dtype,
                device=self.device_torch,
                **load_args
            )

            pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)

            # patch auraflow so it can handle other aspect ratios
            patch_auraflow_pos_embed(pipe.transformer.pos_embed)

            flush()

            # text_encoder = pipe.text_encoder
            # text_encoder.to(self.device_torch, dtype=dtype)
            text_encoder.requires_grad_(False)
            text_encoder.eval()
            pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
            tokenizer = pipe.tokenizer
        else:
            if self.custom_pipeline is not None:
                pipln = self.custom_pipeline
@@ -418,7 +483,7 @@ class StableDiffusion:
# add hacks to unet to help training
# pipe.unet = prepare_unet_for_training(pipe.unet)
if self.is_pixart or self.is_v3:
if self.is_pixart or self.is_v3 or self.is_auraflow:
# pixart, sd3, and auraflow use a transformer rather than a unet
self.unet = pipe.transformer
else:
@@ -621,6 +686,16 @@ class StableDiffusion:
**extra_args
)
elif self.is_auraflow:
pipeline = AuraFlowPipeline(
vae=self.vae,
transformer=self.unet,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
scheduler=noise_scheduler,
**extra_args
)
else:
pipeline = Pipe(
vae=self.vae,
@@ -846,6 +921,24 @@ class StableDiffusion:
).images[0]
elif self.is_pixart:
# pixart requires explicit attention masks when prompt embeds are passed directly
img = pipeline(
prompt=None,
prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype),
negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype),
negative_prompt=None,
# negative_prompt=gen_config.negative_prompt,
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
latents=gen_config.latents,
**extra
).images[0]
elif self.is_auraflow:
pipeline: AuraFlowPipeline = pipeline
img = pipeline(
prompt=None,
prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
@@ -1309,6 +1402,18 @@ class StableDiffusion:
**kwargs,
).sample
noise_pred = precondition_model_outputs_sd3(noise_pred, latent_model_input, timestep)
elif self.is_auraflow:
# auraflow uses timestep values between 0 and 1, with t=1 being pure noise and t=0 the clean image
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
t = torch.tensor([timestep / 1000]).expand(latent_model_input.shape[0])
t = t.to(self.device_torch, self.torch_dtype)
noise_pred = self.unet(
latent_model_input,
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
timestep=t,
return_dict=False,
)[0]
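A quick check of the timestep mapping used above (values are illustrative):

import torch

timestep, batch_size = 800, 2  # illustrative values
t = torch.tensor([timestep / 1000]).expand(batch_size)
print(t)  # tensor([0.8000, 0.8000]) -- t=1.0 would be pure noise, t=0.0 the image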
else:
noise_pred = self.unet(
latent_model_input.to(self.device_torch, self.torch_dtype),
@@ -1502,6 +1607,19 @@ class StableDiffusion:
embeds,
attention_mask=attention_mask,
)
elif self.is_auraflow:
embeds, attention_mask = train_tools.encode_prompts_auraflow(
self.tokenizer,
self.text_encoder,
prompt,
truncate=not long_prompts,
max_length=256,
dropout_prob=dropout_prob
)
return PromptEmbeds(
embeds,
attention_mask=attention_mask, # not used
)
elif isinstance(self.text_encoder, T5EncoderModel):
embeds, attention_mask = train_tools.encode_prompts_pixart(
@@ -1835,7 +1953,7 @@ class StableDiffusion:
named_params = self.named_parameters(vae=False, unet=unet, text_encoder=False, state_dict_keys=True)
unet_lr = unet_lr if unet_lr is not None else default_lr
params = []
if self.is_pixart:
if self.is_pixart or self.is_auraflow:
for param in named_params.values():
if param.requires_grad:
params.append(param)
@@ -1881,7 +1999,7 @@ class StableDiffusion:
def save_device_state(self):
# saves the current device state for all modules
# this is useful for when we want to alter the state and restore it
if self.is_pixart or self.is_v3:
if self.is_pixart or self.is_v3 or self.is_auraflow:
unet_has_grad = self.unet.proj_out.weight.requires_grad
else:
unet_has_grad = self.unet.conv_in.weight.requires_grad
@@ -1912,7 +2030,7 @@ class StableDiffusion:
'requires_grad': te_has_grad
})
else:
if isinstance(self.text_encoder, T5EncoderModel):
if isinstance(self.text_encoder, T5EncoderModel) or isinstance(self.text_encoder, UMT5EncoderModel):
te_has_grad = self.text_encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
else:
te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad

View File

@@ -30,7 +30,7 @@ from diffusers import (
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
import torch
import re
from transformers import T5Tokenizer, T5EncoderModel
from transformers import T5Tokenizer, T5EncoderModel, UMT5EncoderModel
SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120
@@ -725,6 +725,48 @@ def encode_prompts_pixart(
    return prompt_embeds.last_hidden_state, prompt_attention_mask


def encode_prompts_auraflow(
        tokenizer: 'T5Tokenizer',
        text_encoder: 'UMT5EncoderModel',
        prompts: list[str],
        truncate: bool = True,
        max_length=None,
        dropout_prob=0.0,
):
    if max_length is None:
        max_length = 256
    if dropout_prob > 0.0:
        # randomly drop out prompts for caption dropout
        prompts = [
            prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts
        ]
    device = text_encoder.device

    # note: prompts are always padded and truncated to max_length here; the `truncate`
    # flag is accepted for parity with the other encoders but is not applied
    text_inputs = tokenizer(
        prompts,
        truncation=True,
        max_length=max_length,
        padding="max_length",
        return_tensors="pt",
    )
    text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
    text_input_ids = text_inputs["input_ids"]

    untruncated_ids = tokenizer(prompts, padding="longest", return_tensors="pt").input_ids.to(device)
    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
    ):
        removed_text = tokenizer.batch_decode(untruncated_ids[:, max_length - 1: -1])
        print(f"Prompt was truncated to {max_length} tokens. Removed text: {removed_text}")

    prompt_embeds = text_encoder(**text_inputs)[0]
    # zero out the embeddings at padded positions
    prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
    prompt_embeds = prompt_embeds * prompt_attention_mask
    return prompt_embeds, prompt_attention_mask
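A hypothetical usage sketch of the new helper (the checkpoint id and the AutoTokenizer loading are assumptions for illustration; in the trainer the tokenizer and text encoder come from the loaded AuraFlow pipeline):

import torch
from transformers import AutoTokenizer, UMT5EncoderModel

repo = "fal/AuraFlow"  # assumed checkpoint id
tokenizer = AutoTokenizer.from_pretrained(repo, subfolder="tokenizer")
text_encoder = UMT5EncoderModel.from_pretrained(repo, subfolder="text_encoder")

with torch.no_grad():
    embeds, mask = encode_prompts_auraflow(tokenizer, text_encoder, ["a photo of a cat"])
print(embeds.shape)  # (1, 256, hidden_dim); padded positions are zeroed via the expanded mask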
# for XL
def get_add_time_ids(
height: int,