Added Flux training. Still a WIP. It won't train correctly until rectified flow is working properly.

This commit is contained in:
Jaret Burkett
2024-08-02 15:00:30 -06:00
parent 03613c523f
commit 87ba867fdc
6 changed files with 292 additions and 15 deletions

View File

@@ -41,17 +41,21 @@ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAda
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \
StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel
StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel, FluxPipeline, \
FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler
import diffusers
from diffusers import \
AutoencoderKL, \
UNet2DConditionModel
from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler, PixArtSigmaPipeline
from transformers import T5EncoderModel, BitsAndBytesConfig, UMT5EncoderModel
from transformers import T5EncoderModel, BitsAndBytesConfig, UMT5EncoderModel, T5TokenizerFast
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
from toolkit.util.inverse_cfg import inverse_classifier_guidance
from optimum.quanto import freeze, qfloat8, quantize
# tell it to shut up
diffusers.logging.set_verbosity(diffusers.logging.ERROR)
@@ -78,6 +82,7 @@ DO_NOT_TRAIN_WEIGHTS = [
DeviceStatePreset = Literal['cache_latents', 'generate']
class BlankNetwork:
def __init__(self):
@@ -101,10 +106,6 @@ def flush():
UNET_IN_CHANNELS = 4 # Stable Diffusion's in_channels is fixed at 4. The same applies to XL.
# VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
# if is type checking
if typing.TYPE_CHECKING:
from diffusers.schedulers import KarrasDiffusionSchedulers
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
class StableDiffusion:
@@ -158,6 +159,7 @@ class StableDiffusion:
self.is_vega = model_config.is_vega
self.is_pixart = model_config.is_pixart
self.is_auraflow = model_config.is_auraflow
self.is_flux = model_config.is_flux
self.use_text_encoder_1 = model_config.use_text_encoder_1
self.use_text_encoder_2 = model_config.use_text_encoder_2
@@ -443,6 +445,71 @@ class StableDiffusion:
text_encoder.eval()
pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
tokenizer = pipe.tokenizer
elif self.model_config.is_flux:
print("Loading Flux model")
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler")
print("Loading vae")
vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae", torch_dtype=dtype)
flush()
print("Loading transformer")
transformer = FluxTransformer2DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=dtype)
transformer.to(self.device_torch, dtype=dtype)
flush()
if self.model_config.quantize:
print("Quantizing transformer")
quantize(transformer, weights=qfloat8)
freeze(transformer)
flush()
print("Loading t5")
text_encoder_2 = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder_2", torch_dtype=dtype)
tokenizer_2 = T5TokenizerFast.from_pretrained(model_path, subfolder="tokenizer_2", torch_dtype=dtype)
text_encoder_2.to(self.device_torch, dtype=dtype)
flush()
if self.model_config.quantize:
print("Quantizing T5")
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)
flush()
print("Loading clip")
text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=dtype)
tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer", torch_dtype=dtype)
text_encoder.to(self.device_torch, dtype=dtype)
print("making pipe")
pipe = FluxPipeline(
scheduler=scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=None,
tokenizer_2=tokenizer_2,
vae=vae,
transformer=None,
)
pipe.text_encoder_2 = text_encoder_2
pipe.transformer = transformer
print("preparing")
text_encoder = [pipe.text_encoder, pipe.text_encoder_2]
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
pipe.transformer = pipe.transformer.to(self.device_torch)
flush()
text_encoder[0].to(self.device_torch)
text_encoder[0].requires_grad_(False)
text_encoder[0].eval()
text_encoder[1].to(self.device_torch)
text_encoder[1].requires_grad_(False)
text_encoder[1].eval()
pipe.transformer = pipe.transformer.to(self.device_torch)
flush()
else:
if self.custom_pipeline is not None:
pipln = self.custom_pipeline
@@ -515,7 +582,7 @@ class StableDiffusion:
# add hacks to unet to help training
# pipe.unet = prepare_unet_for_training(pipe.unet)
if self.is_pixart or self.is_v3 or self.is_auraflow:
if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux:
# pixart, sd3, auraflow and flux don't use a unet; their transformer is stored as self.unet
self.unet = pipe.transformer
else:
@@ -695,6 +762,18 @@ class StableDiffusion:
**extra_args
).to(self.device_torch)
pipeline.watermark = None
elif self.is_flux:
pipeline = FluxPipeline(
vae=self.vae,
transformer=self.unet,
text_encoder=self.text_encoder[0],
text_encoder_2=self.text_encoder[1],
tokenizer=self.tokenizer[0],
tokenizer_2=self.tokenizer[1],
scheduler=noise_scheduler,
**extra_args
).to(self.device_torch)
pipeline.watermark = None
elif self.is_v3:
pipeline = Pipe(
vae=self.vae,
@@ -954,6 +1033,19 @@ class StableDiffusion:
latents=gen_config.latents,
**extra
).images[0]
elif self.is_flux:
img = pipeline(
prompt_embeds=conditional_embeds.text_embeds,
pooled_prompt_embeds=conditional_embeds.pooled_embeds,
# negative_prompt_embeds=unconditional_embeds.text_embeds,
# negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
latents=gen_config.latents,
**extra
).images[0]
elif self.is_pixart:
# needs attention masks for some reason
img = pipeline(
@@ -1073,10 +1165,14 @@ class StableDiffusion:
if width is None:
width = pixel_width // VAE_SCALE_FACTOR
num_channels = self.unet.config['in_channels']
if self.is_flux:
# the transformer config reports 64 in_channels, but the noise latents are created with 16 channels
num_channels = 16
noise = torch.randn(
(
batch_size,
self.unet.config['in_channels'],
num_channels,
height,
width,
),
@@ -1429,7 +1525,88 @@ class StableDiffusion:
self.unet.to(self.device_torch)
if self.unet.dtype != self.torch_dtype:
self.unet = self.unet.to(dtype=self.torch_dtype)
if self.is_v3:
if self.is_flux:
with torch.no_grad():
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) # 16 . Maybe dont subtract
# this is what diffusers does
text_ids = torch.zeros(latent_model_input.shape[0], text_embeddings.text_embeds.shape[1], 3).to(
device=self.device_torch, dtype=self.text_encoder[0].dtype
)
# todo check these
# height = latent_model_input.shape[2] * VAE_SCALE_FACTOR
# width = latent_model_input.shape[3] * VAE_SCALE_FACTOR
height = latent_model_input.shape[2] * VAE_SCALE_FACTOR # 128
width = latent_model_input.shape[3] * VAE_SCALE_FACTOR # 128
width_latent = latent_model_input.shape[3]
height_latent = latent_model_input.shape[2]
latent_image_ids = self.pipeline._prepare_latent_image_ids(
batch_size=latent_model_input.shape[0],
height=height_latent,
width=width_latent,
device=self.device_torch,
dtype=self.torch_dtype,
)
# # handle guidance
guidance_scale = 1.0 # ?
if self.unet.config.guidance_embeds:
guidance = torch.tensor([guidance_scale], device=self.device_torch)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
# not sure how to handle this
# sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
# image_seq_len = latents.shape[1]
# mu = calculate_shift(
# image_seq_len,
# self.scheduler.config.base_image_seq_len,
# self.scheduler.config.max_image_seq_len,
# self.scheduler.config.base_shift,
# self.scheduler.config.max_shift,
# )
# timesteps, num_inference_steps = retrieve_timesteps(
# self.scheduler,
# num_inference_steps,
# device,
# timesteps,
# sigmas,
# mu=mu,
# )
latent_model_input = self.pipeline._pack_latents(
latent_model_input,
batch_size=latent_model_input.shape[0],
num_channels_latents=latent_model_input.shape[1], # 16
height=height_latent, # 128
width=width_latent, # 128
)
noise_pred = self.unet(
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype), # [1, 4096, 64]
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
# todo make sure this doesnt change
timestep=timestep / 1000, # timestep is 1000 scale
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), # [1, 512, 4096]
pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype), # [1, 768]
txt_ids=text_ids, # [1, 512, 3]
img_ids=latent_image_ids, # [1, 4096, 3]
guidance=guidance,
return_dict=False,
**kwargs,
)[0]
# unpack latents
noise_pred = self.pipeline._unpack_latents(
noise_pred,
height=height, # 1024
width=height, # 1024
vae_scale_factor=VAE_SCALE_FACTOR * 2, # should be 16 not sure why
)
elif self.is_v3:
noise_pred = self.unet(
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
timestep=timestep,
@@ -1656,6 +1833,21 @@ class StableDiffusion:
embeds,
attention_mask=attention_mask, # not used
)
elif self.is_flux:
prompt_embeds, pooled_prompt_embeds = train_tools.encode_prompts_flux(
self.tokenizer, # list
self.text_encoder, # list
prompt,
truncate=not long_prompts,
max_length=512,
dropout_prob=dropout_prob
)
pe = PromptEmbeds(
prompt_embeds
)
pe.pooled_embeds = pooled_prompt_embeds
return pe
elif isinstance(self.text_encoder, T5EncoderModel):
embeds, attention_mask = train_tools.encode_prompts_pixart(
@@ -1989,7 +2181,7 @@ class StableDiffusion:
named_params = self.named_parameters(vae=False, unet=unet, text_encoder=False, state_dict_keys=True)
unet_lr = unet_lr if unet_lr is not None else default_lr
params = []
if self.is_pixart or self.is_auraflow:
if self.is_pixart or self.is_auraflow or self.is_flux:
for param in named_params.values():
if param.requires_grad:
params.append(param)
@@ -2035,7 +2227,7 @@ class StableDiffusion:
def save_device_state(self):
# saves the current device state for all modules
# this is useful for when we want to alter the state and restore it
if self.is_pixart or self.is_v3 or self.is_auraflow:
if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux:
unet_has_grad = self.unet.proj_out.weight.requires_grad
else:
unet_has_grad = self.unet.conv_in.weight.requires_grad