Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-02-04 12:39:58 +00:00)

Partial implementation for training auraflow.
@@ -883,6 +883,26 @@ class SDTrainer(BaseSDTrainProcess):
    def end_of_training_loop(self):
        pass

    def predict_noise(
            self,
            noisy_latents: torch.Tensor,
            timesteps: Union[int, torch.Tensor] = 1,
            conditional_embeds: Union[PromptEmbeds, None] = None,
            unconditional_embeds: Union[PromptEmbeds, None] = None,
            **kwargs,
    ):
        dtype = get_torch_dtype(self.train_config.dtype)
        return self.sd.predict_noise(
            latents=noisy_latents.to(self.device_torch, dtype=dtype),
            conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
            unconditional_embeddings=unconditional_embeds,
            timestep=timesteps,
            guidance_scale=self.train_config.cfg_scale,
            detach_unconditional=False,
            rescale_cfg=self.train_config.cfg_rescale,
            **kwargs
        )

    def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
        self.timer.start('preprocess_batch')
        batch = self.preprocess_batch(batch)
@@ -1453,14 +1473,11 @@ class SDTrainer(BaseSDTrainProcess):
        with self.timer('predict_unet'):
            if unconditional_embeds is not None:
                unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
-           noise_pred = self.sd.predict_noise(
-               latents=noisy_latents.to(self.device_torch, dtype=dtype),
-               conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
-               unconditional_embeddings=unconditional_embeds,
-               timestep=timesteps,
-               guidance_scale=self.train_config.cfg_scale,
-               detach_unconditional=False,
-               rescale_cfg=self.train_config.cfg_rescale,
+           noise_pred = self.predict_noise(
+               noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
+               timesteps=timesteps,
+               conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
+               unconditional_embeds=unconditional_embeds,
                **pred_kwargs
            )
        self.after_unet_predict()
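
The second hunk routes the trainer's UNet call through the predict_noise wrapper introduced above, so the CFG plumbing lives in one place. A sketch of why that helps: a model-specific trainer can now override a single method instead of editing every call site (the subclass below is hypothetical and not part of this commit):

class NoRescaleTrainer(SDTrainer):
    # hypothetical override: identical to the base wrapper except for one argument
    def predict_noise(self, noisy_latents, timesteps=1, conditional_embeds=None,
                      unconditional_embeds=None, **kwargs):
        dtype = get_torch_dtype(self.train_config.dtype)
        return self.sd.predict_noise(
            latents=noisy_latents.to(self.device_torch, dtype=dtype),
            conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
            unconditional_embeddings=unconditional_embeds,
            timestep=timesteps,
            guidance_scale=self.train_config.cfg_scale,
            detach_unconditional=False,
            rescale_cfg=False,  # the only change from the base implementation
            **kwargs
        )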

@@ -1287,6 +1287,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
            is_v2=self.model_config.is_v2,
            is_v3=self.model_config.is_v3,
            is_pixart=self.model_config.is_pixart,
            is_auraflow=self.model_config.is_auraflow,
            is_ssd=self.model_config.is_ssd,
            is_vega=self.model_config.is_vega,
            dropout=self.network_config.dropout,

@@ -53,6 +53,50 @@ resolutions_1024: List[BucketResolution] = [
    {"width": 512, "height": 2048},
]

# Even numbers so they can be patched easier
resolutions_dit_1024: List[BucketResolution] = [
    # Base resolution
    {"width": 1024, "height": 1024},
    # widescreen
    {"width": 2048, "height": 512},
    {"width": 1792, "height": 576},
    {"width": 1728, "height": 576},
    {"width": 1664, "height": 576},
    {"width": 1600, "height": 640},
    {"width": 1536, "height": 640},
    {"width": 1472, "height": 704},
    {"width": 1408, "height": 704},
    {"width": 1344, "height": 704},
    {"width": 1344, "height": 768},
    {"width": 1280, "height": 768},
    {"width": 1216, "height": 832},
    {"width": 1152, "height": 832},
    {"width": 1152, "height": 896},
    {"width": 1088, "height": 896},
    {"width": 1088, "height": 960},
    {"width": 1024, "height": 960},
    # portrait
    {"width": 960, "height": 1024},
    {"width": 960, "height": 1088},
    {"width": 896, "height": 1088},
    {"width": 896, "height": 1152},  # 2:3
    {"width": 832, "height": 1152},
    {"width": 832, "height": 1216},
    {"width": 768, "height": 1280},
    {"width": 768, "height": 1344},
    {"width": 704, "height": 1408},
    {"width": 704, "height": 1472},
    {"width": 640, "height": 1536},
    {"width": 640, "height": 1600},
    {"width": 576, "height": 1664},
    {"width": 576, "height": 1728},
    {"width": 576, "height": 1792},
    {"width": 512, "height": 1856},
    {"width": 512, "height": 1920},
    {"width": 512, "height": 1984},
    {"width": 512, "height": 2048},
]


def get_bucket_sizes(resolution: int = 512, divisibility: int = 8) -> List[BucketResolution]:
    # determine scaler form 1024 to resolution
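
Every entry in resolutions_dit_1024 keeps both edges divisible by 64 and stays within the 1024x1024 pixel budget, which is what makes the latents easy to patchify. A quick standalone check (the divisor of 64 is an assumption about the 8x VAE downscale combined with an even patch grid; the diff itself only says "even numbers"):

# spot-check a few of the buckets listed above
buckets = [(1024, 1024), (2048, 512), (1792, 576), (1216, 832), (896, 1152), (512, 2048)]
for width, height in buckets:
    assert width % 64 == 0 and height % 64 == 0, (width, height)  # clean patch grid after VAE downscale
    assert width * height <= 1024 * 1024, (width, height)         # never exceeds the base pixel budget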

@@ -350,6 +350,7 @@ class ModelConfig:
        self.is_xl: bool = kwargs.get('is_xl', False)
        self.is_pixart: bool = kwargs.get('is_pixart', False)
        self.is_pixart_sigma: bool = kwargs.get('is_pixart_sigma', False)
        self.is_auraflow: bool = kwargs.get('is_auraflow', False)
        self.is_v3: bool = kwargs.get('is_v3', False)
        if self.is_pixart_sigma:
            self.is_pixart = True

@@ -381,7 +382,7 @@ class ModelConfig:
            self.is_xl = True

        # for text encoder quant. Only works with pixart currently
-       self.text_encoder_bits = kwargs.get('text_encoder_bits', 8)  # 16, 8, 4
+       self.text_encoder_bits = kwargs.get('text_encoder_bits', 16)  # 16, 8, 4
        self.unet_path = kwargs.get("unet_path", None)
        self.unet_sample_size = kwargs.get("unet_sample_size", None)
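
Both flags are read straight from kwargs, so an AuraFlow run would configure the model roughly like this (the checkpoint id and the name_or_path key are illustrative; only is_auraflow and text_encoder_bits appear in this diff):

model_config = ModelConfig(
    name_or_path="fal/AuraFlow",  # illustrative checkpoint reference
    is_auraflow=True,             # new flag added in this commit
    text_encoder_bits=16,         # new default; 8 or 4 switch the text encoder to bitsandbytes quantization
)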

@@ -1355,6 +1355,10 @@ class LatentCachingMixin:
                file_item.latent_space_version = 'sdxl'
            elif self.sd.is_v3:
                file_item.latent_space_version = 'sd3'
            elif self.sd.is_auraflow:
                file_item.latent_space_version = 'sdxl'
            elif self.sd.model_config.is_pixart_sigma:
                file_item.latent_space_version = 'sdxl'
            else:
                file_item.latent_space_version = 'sd1'
            file_item.is_caching_to_disk = to_disk
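
The chain above only tags cached latents with the VAE latent space they were produced in: AuraFlow and PixArt Sigma reuse the SDXL VAE, SD3 gets its own tag, everything else falls back to sd1. Restated as a lookup (a hypothetical equivalent for readability, assuming the branch cut off at the top of the hunk is the SDXL check):

def latent_space_version_for(sd) -> str:
    if sd.is_v3:
        return 'sd3'
    if sd.is_xl or sd.is_auraflow or sd.model_config.is_pixart_sigma:
        return 'sdxl'  # these models share the SDXL VAE latent space
    return 'sd1'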

@@ -7,7 +7,7 @@ import re
import sys
from typing import List, Optional, Dict, Type, Union
import torch
-from diffusers import UNet2DConditionModel, PixArtTransformer2DModel
+from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel
from transformers import CLIPTextModel

from .config_modules import NetworkConfig

@@ -158,6 +158,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
            is_v2=False,
            is_v3=False,
            is_pixart: bool = False,
            is_auraflow: bool = False,
            use_bias: bool = False,
            is_lorm: bool = False,
            ignore_if_contains = None,

@@ -212,6 +213,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
        self.is_v2 = is_v2
        self.is_v3 = is_v3
        self.is_pixart = is_pixart
        self.is_auraflow = is_auraflow
        self.network_type = network_type
        if self.network_type.lower() == "dora":
            self.module_class = DoRAModule

@@ -246,7 +248,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
            target_replace_modules: List[torch.nn.Module],
        ) -> List[LoRAModule]:
            unet_prefix = self.LORA_PREFIX_UNET
-           if is_pixart or is_v3:
+           if is_pixart or is_v3 or is_auraflow:
                unet_prefix = f"lora_transformer"

            prefix = (

@@ -371,6 +373,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
        if is_pixart:
            target_modules = ["PixArtTransformer2DModel"]

        if is_auraflow:
            target_modules = ["AuraFlowTransformer2DModel"]

        if train_unet:
            self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
        else:
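
Setting target_modules to ["AuraFlowTransformer2DModel"] makes create_modules wrap layers inside the transformer rather than UNet blocks. In this kohya-derived network the selection boils down to a class-name match; a simplified sketch of the idea (not the actual create_modules implementation):

import torch.nn as nn

def find_lora_targets(root: nn.Module, target_class_names: list) -> list:
    # collect the Linear/Conv2d submodules that live inside a targeted block type
    names = []
    for name, module in root.named_modules():
        if module.__class__.__name__ in target_class_names:
            for child_name, child in module.named_modules():
                if isinstance(child, (nn.Linear, nn.Conv2d)):
                    names.append(f"{name}.{child_name}".strip("."))
    return names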

@@ -408,6 +413,14 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
            transformer.pos_embed = self.transformer_pos_embed
            transformer.proj_out = self.transformer_proj_out

        elif self.is_auraflow:
            transformer: AuraFlowTransformer2DModel = unet
            self.transformer_pos_embed = copy.deepcopy(transformer.pos_embed)
            self.transformer_proj_out = copy.deepcopy(transformer.proj_out)

            transformer.pos_embed = self.transformer_pos_embed
            transformer.proj_out = self.transformer_proj_out

        else:
            unet: UNet2DConditionModel = unet
            unet_conv_in: torch.nn.Conv2d = unet.conv_in

@@ -424,7 +437,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
        all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr)

        if self.full_train_in_out:
-           if self.is_pixart:
+           if self.is_pixart or self.is_auraflow:
                all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())})
                all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())})
            else:
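
When full_train_in_out is set, the copied pos_embed and proj_out modules are trained in full alongside the LoRA weights by appending extra parameter groups, each with its own learning rate. That is the standard torch pattern; a self-contained illustration with stand-in modules:

import torch
from torch import nn

pos_embed = nn.Linear(64, 256)   # stands in for transformer.pos_embed
proj_out = nn.Linear(256, 64)    # stands in for transformer.proj_out
lora_params = [nn.Parameter(torch.zeros(256, 4)), nn.Parameter(torch.zeros(4, 256))]

unet_lr = 1e-5
all_params = [{"lr": 1e-4, "params": lora_params}]
all_params.append({"lr": unet_lr, "params": list(pos_embed.parameters())})
all_params.append({"lr": unet_lr, "params": list(proj_out.parameters())})

optimizer = torch.optim.AdamW(all_params)  # each group keeps its own learning rate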

toolkit/models/auraflow.py (new file, 127 lines)
@@ -0,0 +1,127 @@
import math
from functools import partial

from torch import nn
import torch


class AuraFlowPatchEmbed(nn.Module):
    def __init__(
            self,
            height=224,
            width=224,
            patch_size=16,
            in_channels=3,
            embed_dim=768,
            pos_embed_max_size=None,
    ):
        super().__init__()

        self.num_patches = (height // patch_size) * (width // patch_size)
        self.pos_embed_max_size = pos_embed_max_size

        self.proj = nn.Linear(patch_size * patch_size * in_channels, embed_dim)
        self.pos_embed = nn.Parameter(torch.randn(1, pos_embed_max_size, embed_dim) * 0.1)

        self.patch_size = patch_size
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size

    def forward(self, latent):
        batch_size, num_channels, height, width = latent.size()
        latent = latent.view(
            batch_size,
            num_channels,
            height // self.patch_size,
            self.patch_size,
            width // self.patch_size,
            self.patch_size,
        )
        latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        latent = self.proj(latent)
        try:
            return latent + self.pos_embed
        except RuntimeError:
            raise RuntimeError(
                f"Positional embeddings are too small for the number of patches. "
                f"Please increase `pos_embed_max_size` to at least {self.num_patches}."
            )


    # comfy
    # def apply_pos_embeds(self, x, h, w):
    #     h = (h + 1) // self.patch_size
    #     w = (w + 1) // self.patch_size
    #     max_dim = max(h, w)
    #
    #     cur_dim = self.h_max
    #     pos_encoding = self.positional_encoding.reshape(1, cur_dim, cur_dim, -1).to(device=x.device, dtype=x.dtype)
    #
    #     if max_dim > cur_dim:
    #         pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
    #         cur_dim = max_dim
    #
    #     from_h = (cur_dim - h) // 2
    #     from_w = (cur_dim - w) // 2
    #     pos_encoding = pos_encoding[:, from_h:from_h + h, from_w:from_w + w]
    #     return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])

    # def patchify(self, x):
    #     B, C, H, W = x.size()
    #     pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
    #     pad_w = (self.patch_size - W % self.patch_size) % self.patch_size
    #
    #     x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
    #     x = x.view(
    #         B,
    #         C,
    #         (H + 1) // self.patch_size,
    #         self.patch_size,
    #         (W + 1) // self.patch_size,
    #         self.patch_size,
    #     )
    #     x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
    #     return x


def patch_auraflow_pos_embed(pos_embed):
    # we need to hijack the forward and replace with a custom one. Self is the model
    def new_forward(self, latent):
        batch_size, num_channels, height, width = latent.size()

        # add padding to the latent to make it match pos_embed
        latent_size = height * width * num_channels / 16  # todo check where 16 comes from?
        pos_embed_size = self.pos_embed.shape[1]
        if latent_size < pos_embed_size:
            total_padding = int(pos_embed_size - math.floor(latent_size))
            total_padding = total_padding // 16
            pad_height = total_padding // 2
            pad_width = total_padding - pad_height
            # mirror padding on the right side
            padding = (0, pad_width, 0, pad_height)
            latent = torch.nn.functional.pad(latent, padding, mode='reflect')
        elif latent_size > pos_embed_size:
            amount_to_remove = latent_size - pos_embed_size
            latent = latent[:, :, :-amount_to_remove]

        batch_size, num_channels, height, width = latent.size()

        latent = latent.view(
            batch_size,
            num_channels,
            height // self.patch_size,
            self.patch_size,
            width // self.patch_size,
            self.patch_size,
        )
        latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        latent = self.proj(latent)
        try:
            return latent + self.pos_embed
        except RuntimeError:
            raise RuntimeError(
                f"Positional embeddings are too small for the number of patches. "
                f"Please increase `pos_embed_max_size` to at least {self.num_patches}."
            )

    pos_embed.forward = partial(new_forward, pos_embed)
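
A minimal shape check for the module above, with toy dimensions chosen so the patch count exactly matches pos_embed_max_size (the real AuraFlow checkpoint uses different sizes). patch_auraflow_pos_embed then swaps in the padding-aware forward so latents from the DiT bucket list can deviate from that exact count:

import torch

# a 64x64 latent with patch_size 2 gives (64 // 2) * (64 // 2) = 1024 patches
embed = AuraFlowPatchEmbed(
    height=64, width=64, patch_size=2, in_channels=4,
    embed_dim=32, pos_embed_max_size=1024,
)
latent = torch.randn(1, 4, 64, 64)
print(embed(latent).shape)       # torch.Size([1, 1024, 32])

patch_auraflow_pos_embed(embed)  # replaces forward with the padding/cropping version
print(embed(latent).shape)       # unchanged here because the sizes already match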

@@ -27,6 +27,7 @@ from library.model_util import convert_unet_state_dict_to_sd, convert_text_encod
from toolkit import train_tools
from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.metadata import get_meta_for_safetensors
from toolkit.models.auraflow import patch_auraflow_pos_embed
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
from toolkit.reference_adapter import ReferenceAdapter

@@ -40,13 +41,14 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
    StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
    StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
-   StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline
+   StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \
+   StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel
import diffusers
from diffusers import \
    AutoencoderKL, \
    UNet2DConditionModel
from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler, PixArtSigmaPipeline
-from transformers import T5EncoderModel, BitsAndBytesConfig
+from transformers import T5EncoderModel, BitsAndBytesConfig, UMT5EncoderModel

from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
from toolkit.util.inverse_cfg import inverse_classifier_guidance

@@ -149,6 +151,7 @@ class StableDiffusion:
        self.is_v3 = model_config.is_v3
        self.is_vega = model_config.is_vega
        self.is_pixart = model_config.is_pixart
        self.is_auraflow = model_config.is_auraflow

        self.use_text_encoder_1 = model_config.use_text_encoder_1
        self.use_text_encoder_2 = model_config.use_text_encoder_2

@@ -371,6 +374,68 @@ class StableDiffusion:
            text_encoder.eval()
            pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
            tokenizer = pipe.tokenizer


        elif self.model_config.is_auraflow:
            te_kwargs = {}
            # handle quantization of TE
            te_is_quantized = False
            if self.model_config.text_encoder_bits == 8:
                te_kwargs['load_in_8bit'] = True
                te_kwargs['device_map'] = "auto"
                te_is_quantized = True
            elif self.model_config.text_encoder_bits == 4:
                te_kwargs['load_in_4bit'] = True
                te_kwargs['device_map'] = "auto"
                te_is_quantized = True

            main_model_path = model_path

            # load the TE in 8bit mode
            text_encoder = UMT5EncoderModel.from_pretrained(
                main_model_path,
                subfolder="text_encoder",
                torch_dtype=self.torch_dtype,
                **te_kwargs
            )

            # load the transformer
            subfolder = "transformer"
            # check if it is just the unet
            if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)):
                subfolder = None

            if te_is_quantized:
                # replace the to function with a no-op since it throws an error instead of a warning
                text_encoder.to = lambda *args, **kwargs: None

            # load the transformer only from the save
            transformer = AuraFlowTransformer2DModel.from_pretrained(
                model_path if self.model_config.unet_path is None else self.model_config.unet_path,
                torch_dtype=self.torch_dtype,
                subfolder='transformer'
            )
            pipe: AuraFlowPipeline = AuraFlowPipeline.from_pretrained(
                main_model_path,
                transformer=transformer,
                text_encoder=text_encoder,
                dtype=dtype,
                device=self.device_torch,
                **load_args
            )

            pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)

            # patch auraflow so it can handle other aspect ratios
            patch_auraflow_pos_embed(pipe.transformer.pos_embed)

            flush()
            # text_encoder = pipe.text_encoder
            # text_encoder.to(self.device_torch, dtype=dtype)
            text_encoder.requires_grad_(False)
            text_encoder.eval()
            pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
            tokenizer = pipe.tokenizer
        else:
            if self.custom_pipeline is not None:
                pipln = self.custom_pipeline
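
The text-encoder branch above relies on the legacy load_in_8bit / load_in_4bit kwargs. Newer transformers releases route the same options through BitsAndBytesConfig, which this file already imports; a hedged equivalent for the 8-bit case (assumes bitsandbytes is installed):

from transformers import BitsAndBytesConfig, UMT5EncoderModel

quant_config = BitsAndBytesConfig(load_in_8bit=True)   # roughly equivalent to te_kwargs['load_in_8bit'] = True
text_encoder = UMT5EncoderModel.from_pretrained(
    main_model_path,              # same path used in the loading code above
    subfolder="text_encoder",
    quantization_config=quant_config,
    device_map="auto",
)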

@@ -418,7 +483,7 @@ class StableDiffusion:
        # add hacks to unet to help training
        # pipe.unet = prepare_unet_for_training(pipe.unet)

-       if self.is_pixart or self.is_v3:
+       if self.is_pixart or self.is_v3 or self.is_auraflow:
            # pixart and sd3 dont use a unet
            self.unet = pipe.transformer
        else:

@@ -621,6 +686,16 @@ class StableDiffusion:
                **extra_args
            )

        elif self.is_auraflow:
            pipeline = AuraFlowPipeline(
                vae=self.vae,
                transformer=self.unet,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
                scheduler=noise_scheduler,
                **extra_args
            )

        else:
            pipeline = Pipe(
                vae=self.vae,

@@ -846,6 +921,24 @@ class StableDiffusion:
            ).images[0]
        elif self.is_pixart:
            # needs attention masks for some reason
            img = pipeline(
                prompt=None,
                prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
                prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype),
                negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
                negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype),
                negative_prompt=None,
                # negative_prompt=gen_config.negative_prompt,
                height=gen_config.height,
                width=gen_config.width,
                num_inference_steps=gen_config.num_inference_steps,
                guidance_scale=gen_config.guidance_scale,
                latents=gen_config.latents,
                **extra
            ).images[0]
        elif self.is_auraflow:
            pipeline: AuraFlowPipeline = pipeline

            img = pipeline(
                prompt=None,
                prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),

@@ -1309,6 +1402,18 @@ class StableDiffusion:
                **kwargs,
            ).sample
            noise_pred = precondition_model_outputs_sd3(noise_pred, latent_model_input, timestep)
        elif self.is_auraflow:
            # aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            t = torch.tensor([timestep / 1000]).expand(latent_model_input.shape[0])
            t = t.to(self.device_torch, self.torch_dtype)

            noise_pred = self.unet(
                latent_model_input,
                encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
                timestep=t,
                return_dict=False,
            )[0]
        else:
            noise_pred = self.unet(
                latent_model_input.to(self.device_torch, self.torch_dtype),
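
The comment in the AuraFlow branch pins down the convention: t lives in [0, 1], t=1 is pure noise, t=0 is the clean image, and integer scheduler timesteps are divided by 1000 before the transformer sees them. Under the linear flow-matching interpolation that convention implies, the noisy latent and the usual regression target look like this (a sketch of the convention only; the exact loss used for AuraFlow training may differ):

import torch

def make_flow_matching_pair(latents, noise, timesteps_0_to_1000):
    # scale integer timesteps into [0, 1] and broadcast over the batch, as above
    t = (timesteps_0_to_1000.float() / 1000.0).view(-1, 1, 1, 1)
    noisy_latents = (1.0 - t) * latents + t * noise  # t=1 -> noise, t=0 -> image
    target = noise - latents                         # constant-velocity flow target
    return noisy_latents, target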

@@ -1502,6 +1607,19 @@ class StableDiffusion:
                embeds,
                attention_mask=attention_mask,
            )
        elif self.is_auraflow:
            embeds, attention_mask = train_tools.encode_prompts_auraflow(
                self.tokenizer,
                self.text_encoder,
                prompt,
                truncate=not long_prompts,
                max_length=256,
                dropout_prob=dropout_prob
            )
            return PromptEmbeds(
                embeds,
                attention_mask=attention_mask,  # not used
            )

        elif isinstance(self.text_encoder, T5EncoderModel):
            embeds, attention_mask = train_tools.encode_prompts_pixart(

@@ -1835,7 +1953,7 @@ class StableDiffusion:
        named_params = self.named_parameters(vae=False, unet=unet, text_encoder=False, state_dict_keys=True)
        unet_lr = unet_lr if unet_lr is not None else default_lr
        params = []
-       if self.is_pixart:
+       if self.is_pixart or self.is_auraflow:
            for param in named_params.values():
                if param.requires_grad:
                    params.append(param)

@@ -1881,7 +1999,7 @@ class StableDiffusion:
    def save_device_state(self):
        # saves the current device state for all modules
        # this is useful for when we want to alter the state and restore it
-       if self.is_pixart or self.is_v3:
+       if self.is_pixart or self.is_v3 or self.is_auraflow:
            unet_has_grad = self.unet.proj_out.weight.requires_grad
        else:
            unet_has_grad = self.unet.conv_in.weight.requires_grad

@@ -1912,7 +2030,7 @@ class StableDiffusion:
                'requires_grad': te_has_grad
            })
        else:
-           if isinstance(self.text_encoder, T5EncoderModel):
+           if isinstance(self.text_encoder, T5EncoderModel) or isinstance(self.text_encoder, UMT5EncoderModel):
                te_has_grad = self.text_encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
            else:
                te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad

@@ -30,7 +30,7 @@ from diffusers import (
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
import torch
import re
-from transformers import T5Tokenizer, T5EncoderModel
+from transformers import T5Tokenizer, T5EncoderModel, UMT5EncoderModel

SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120

@@ -725,6 +725,48 @@ def encode_prompts_pixart(
    return prompt_embeds.last_hidden_state, prompt_attention_mask


def encode_prompts_auraflow(
        tokenizer: 'T5Tokenizer',
        text_encoder: 'UMT5EncoderModel',
        prompts: list[str],
        truncate: bool = True,
        max_length=None,
        dropout_prob=0.0,
):
    if max_length is None:
        max_length = 256

    if dropout_prob > 0.0:
        # randomly drop out prompts
        prompts = [
            prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts
        ]

    device = text_encoder.device

    text_inputs = tokenizer(
        prompts,
        truncation=True,
        max_length=max_length,
        padding="max_length",
        return_tensors="pt",
    )
    text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
    text_input_ids = text_inputs["input_ids"]
    untruncated_ids = tokenizer(prompts, padding="longest", return_tensors="pt").input_ids

    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
    ):
        removed_text = tokenizer.batch_decode(untruncated_ids[:, max_length - 1: -1])

    prompt_embeds = text_encoder(**text_inputs)[0]
    prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
    prompt_embeds = prompt_embeds * prompt_attention_mask

    return prompt_embeds, prompt_attention_mask
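
The last two lines of encode_prompts_auraflow zero out the embeddings at padded positions by broadcasting the attention mask across the hidden dimension. A tiny standalone illustration with dummy tensors:

import torch

prompt_embeds = torch.randn(2, 256, 2048)              # (batch, tokens, hidden), dummy values
attention_mask = torch.zeros(2, 256, dtype=torch.long)
attention_mask[:, :5] = 1                              # pretend only 5 real tokens per prompt

mask = attention_mask.unsqueeze(-1).expand(prompt_embeds.shape)
masked = prompt_embeds * mask                          # padded positions become exactly zero
print(masked[0, 5:].abs().sum())                       # tensor(0.)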


# for XL
def get_add_time_ids(
        height: int,