Partial implementation for training auraflow.

Jaret Burkett committed 2024-07-12 12:11:38 -06:00
parent c062b7716c
commit e4558dff4b
9 changed files with 386 additions and 19 deletions


@@ -27,6 +27,7 @@ from library.model_util import convert_unet_state_dict_to_sd, convert_text_encod
from toolkit import train_tools
from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.metadata import get_meta_for_safetensors
from toolkit.models.auraflow import patch_auraflow_pos_embed
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
from toolkit.reference_adapter import ReferenceAdapter
@@ -40,13 +41,14 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
    StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
    StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
    StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline
    StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \
    StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel
import diffusers
from diffusers import \
    AutoencoderKL, \
    UNet2DConditionModel
from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler, PixArtSigmaPipeline
from transformers import T5EncoderModel, BitsAndBytesConfig
from transformers import T5EncoderModel, BitsAndBytesConfig, UMT5EncoderModel
from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
from toolkit.util.inverse_cfg import inverse_classifier_guidance
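
Note: AuraFlowPipeline and AuraFlowTransformer2DModel only exist in recent diffusers builds, so the imports above fail on older installs. A guard along these lines (not part of this commit) would fail fast with a clearer message:

try:
    from diffusers import AuraFlowPipeline, AuraFlowTransformer2DModel
except ImportError as e:
    raise ImportError(
        "This branch requires a diffusers build that includes AuraFlow support"
    ) from e
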
@@ -149,6 +151,7 @@ class StableDiffusion:
        self.is_v3 = model_config.is_v3
        self.is_vega = model_config.is_vega
        self.is_pixart = model_config.is_pixart
        self.is_auraflow = model_config.is_auraflow
        self.use_text_encoder_1 = model_config.use_text_encoder_1
        self.use_text_encoder_2 = model_config.use_text_encoder_2
@@ -371,6 +374,68 @@ class StableDiffusion:
            text_encoder.eval()
            pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
            tokenizer = pipe.tokenizer
        elif self.model_config.is_auraflow:
            te_kwargs = {}
            # handle quantization of the text encoder
            te_is_quantized = False
            if self.model_config.text_encoder_bits == 8:
                te_kwargs['load_in_8bit'] = True
                te_kwargs['device_map'] = "auto"
                te_is_quantized = True
            elif self.model_config.text_encoder_bits == 4:
                te_kwargs['load_in_4bit'] = True
                te_kwargs['device_map'] = "auto"
                te_is_quantized = True
            main_model_path = model_path
            # load the text encoder, quantized if requested
            text_encoder = UMT5EncoderModel.from_pretrained(
                main_model_path,
                subfolder="text_encoder",
                torch_dtype=self.torch_dtype,
                **te_kwargs
            )
            # check whether the path points directly at the transformer weights
            # rather than a full pipeline folder
            subfolder = "transformer"
            if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)):
                subfolder = None
            if te_is_quantized:
                # quantized models raise an error on .to() instead of warning; replace
                # it with a no-op that returns the module so chained calls still work
                text_encoder.to = lambda *args, **kwargs: text_encoder
            # load the transformer separately so a standalone transformer
            # checkpoint (unet_path) can override the one in the pipeline folder
            transformer = AuraFlowTransformer2DModel.from_pretrained(
                model_path if self.model_config.unet_path is None else self.model_config.unet_path,
                torch_dtype=self.torch_dtype,
                subfolder=subfolder
            )
            pipe: AuraFlowPipeline = AuraFlowPipeline.from_pretrained(
                main_model_path,
                transformer=transformer,
                text_encoder=text_encoder,
                torch_dtype=dtype,
                **load_args
            )
            pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
            # patch auraflow's positional embedding so it can handle other aspect ratios
            patch_auraflow_pos_embed(pipe.transformer.pos_embed)
            flush()
            text_encoder.requires_grad_(False)
            text_encoder.eval()
            tokenizer = pipe.tokenizer
        else:
            if self.custom_pipeline is not None:
                pipln = self.custom_pipeline
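
For reference, a minimal standalone sketch of the loading path above, assuming the public fal/AuraFlow checkpoint layout and using a BitsAndBytesConfig in place of the deprecated load_in_8bit kwarg; this is an illustration, not the repo's code:

import torch
from transformers import UMT5EncoderModel, BitsAndBytesConfig
from diffusers import AuraFlowPipeline, AuraFlowTransformer2DModel

model_path = "fal/AuraFlow"

# quantize the UMT5 text encoder to 8-bit to cut VRAM use
text_encoder = UMT5EncoderModel.from_pretrained(
    model_path,
    subfolder="text_encoder",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)
# load the transformer in half precision
transformer = AuraFlowTransformer2DModel.from_pretrained(
    model_path, subfolder="transformer", torch_dtype=torch.float16
)
# assemble the pipeline around the preloaded components
pipe = AuraFlowPipeline.from_pretrained(
    model_path,
    transformer=transformer,
    text_encoder=text_encoder,
    torch_dtype=torch.float16,
)
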
@@ -418,7 +483,7 @@ class StableDiffusion:
        # add hacks to unet to help training
        # pipe.unet = prepare_unet_for_training(pipe.unet)
        if self.is_pixart or self.is_v3:
        if self.is_pixart or self.is_v3 or self.is_auraflow:
            # pixart, sd3, and auraflow use a transformer instead of a unet
            self.unet = pipe.transformer
        else:
@@ -621,6 +686,16 @@ class StableDiffusion:
                **extra_args
            )
        elif self.is_auraflow:
            pipeline = AuraFlowPipeline(
                vae=self.vae,
                transformer=self.unet,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
                scheduler=noise_scheduler,
                **extra_args
            )
        else:
            pipeline = Pipe(
                vae=self.vae,
@@ -846,6 +921,24 @@ class StableDiffusion:
            ).images[0]
        elif self.is_pixart:
            # PixArt expects explicit prompt attention masks
            img = pipeline(
                prompt=None,
                prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
                prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype),
                negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
                negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype),
                negative_prompt=None,
                height=gen_config.height,
                width=gen_config.width,
                num_inference_steps=gen_config.num_inference_steps,
                guidance_scale=gen_config.guidance_scale,
                latents=gen_config.latents,
                **extra
            ).images[0]
        elif self.is_auraflow:
            pipeline: AuraFlowPipeline = pipeline
            img = pipeline(
                prompt=None,
                prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
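
The hunk is truncated here. A hedged sketch of the call it is building, continuing the pipe from the loading sketch above; argument names follow diffusers' AuraFlowPipeline and its encode_prompt helper, and the prompt is a stand-in:

# encode once, then sample from the precomputed embeddings
prompt_embeds, prompt_mask, neg_embeds, neg_mask = pipe.encode_prompt(
    prompt="a photo of a red fox in the snow",
    negative_prompt="",
)
img = pipe(
    prompt=None,
    prompt_embeds=prompt_embeds,
    prompt_attention_mask=prompt_mask,
    negative_prompt_embeds=neg_embeds,
    negative_prompt_attention_mask=neg_mask,
    height=1024,
    width=1024,
    num_inference_steps=25,
    guidance_scale=3.5,
).images[0]
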
@@ -1309,6 +1402,18 @@ class StableDiffusion:
                **kwargs,
            ).sample
            noise_pred = precondition_model_outputs_sd3(noise_pred, latent_model_input, timestep)
        elif self.is_auraflow:
            # AuraFlow uses timestep values between 0 and 1, where t=1 is pure noise
            # and t=0 is the clean image; broadcast to the batch dimension in a way
            # that is compatible with ONNX/Core ML
            t = torch.tensor([timestep / 1000]).expand(latent_model_input.shape[0])
            t = t.to(self.device_torch, self.torch_dtype)
            noise_pred = self.unet(
                latent_model_input,
                encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
                timestep=t,
                return_dict=False,
            )[0]
        else:
            noise_pred = self.unet(
                latent_model_input.to(self.device_torch, self.torch_dtype),
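
A worked example of the timestep convention above, assuming integer scheduler timesteps in [0, 1000]:

import torch

timestep = 500   # midpoint of the schedule
batch_size = 4
t = torch.tensor([timestep / 1000]).expand(batch_size)
print(t)         # tensor([0.5000, 0.5000, 0.5000, 0.5000])
print(t.shape)   # torch.Size([4])
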
@@ -1502,6 +1607,19 @@ class StableDiffusion:
                embeds,
                attention_mask=attention_mask,
            )
        elif self.is_auraflow:
            embeds, attention_mask = train_tools.encode_prompts_auraflow(
                self.tokenizer,
                self.text_encoder,
                prompt,
                truncate=not long_prompts,
                max_length=256,
                dropout_prob=dropout_prob
            )
            return PromptEmbeds(
                embeds,
                attention_mask=attention_mask,  # not used
            )
        elif isinstance(self.text_encoder, T5EncoderModel):
            embeds, attention_mask = train_tools.encode_prompts_pixart(
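
encode_prompts_auraflow lives in the repo's train_tools module and is not shown in this commit. A hedged sketch of what such a UMT5 prompt-encoding helper typically looks like, omitting the truncate/dropout_prob handling; the actual implementation may differ:

import torch

def encode_prompts_umt5(tokenizer, text_encoder, prompts, max_length=256):
    # pad/truncate to a fixed length so embeddings batch cleanly
    tokens = tokenizer(
        prompts,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    ).to(text_encoder.device)
    with torch.no_grad():
        embeds = text_encoder(
            tokens.input_ids, attention_mask=tokens.attention_mask
        )[0]
    return embeds, tokens.attention_mask
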
@@ -1835,7 +1953,7 @@ class StableDiffusion:
        named_params = self.named_parameters(vae=False, unet=unet, text_encoder=False, state_dict_keys=True)
        unet_lr = unet_lr if unet_lr is not None else default_lr
        params = []
        if self.is_pixart:
        if self.is_pixart or self.is_auraflow:
            for param in named_params.values():
                if param.requires_grad:
                    params.append(param)
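
Collecting only parameters with requires_grad means optimizer state (e.g. AdamW moments) is never allocated for frozen weights. A self-contained toy illustration of the same pattern; the Linear stand-in is hypothetical:

import torch
from torch import nn

transformer = nn.Linear(8, 8)            # stand-in for the AuraFlow transformer
transformer.bias.requires_grad_(False)   # pretend part of the model is frozen

params = [p for p in transformer.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(params, lr=1e-5)
print(len(params))  # 1 -- only the weight tensor is trainable
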
@@ -1881,7 +1999,7 @@ class StableDiffusion:
    def save_device_state(self):
        # saves the current device state for all modules
        # this is useful for when we want to alter the state and restore it
        if self.is_pixart or self.is_v3:
        if self.is_pixart or self.is_v3 or self.is_auraflow:
            unet_has_grad = self.unet.proj_out.weight.requires_grad
        else:
            unet_has_grad = self.unet.conv_in.weight.requires_grad
@@ -1912,7 +2030,7 @@ class StableDiffusion:
                'requires_grad': te_has_grad
            })
        else:
            if isinstance(self.text_encoder, T5EncoderModel):
            if isinstance(self.text_encoder, (T5EncoderModel, UMT5EncoderModel)):
                te_has_grad = self.text_encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
            else:
                te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad