diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 801337f2..cd90808b 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -1234,7 +1234,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
             torch.backends.cuda.enable_mem_efficient_sdp(True)
 
         if self.train_config.gradient_checkpointing:
-            unet.enable_gradient_checkpointing()
+            if self.sd.is_flux:
+                unet.gradient_checkpointing = True
+            else:
+                unet.enable_gradient_checkpointing()
             if isinstance(text_encoder, list):
                 for te in text_encoder:
                     if hasattr(te, 'enable_gradient_checkpointing'):
@@ -1325,6 +1328,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 is_v3=self.model_config.is_v3,
                 is_pixart=self.model_config.is_pixart,
                 is_auraflow=self.model_config.is_auraflow,
+                is_flux=self.model_config.is_flux,
                 is_ssd=self.model_config.is_ssd,
                 is_vega=self.model_config.is_vega,
                 dropout=self.network_config.dropout,
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index ec3efca1..475c4845 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -367,6 +367,7 @@ class ModelConfig:
         self.is_pixart_sigma: bool = kwargs.get('is_pixart_sigma', False)
         self.is_auraflow: bool = kwargs.get('is_auraflow', False)
         self.is_v3: bool = kwargs.get('is_v3', False)
+        self.is_flux: bool = kwargs.get('is_flux', False)
         if self.is_pixart_sigma:
             self.is_pixart = True
         self.is_ssd: bool = kwargs.get('is_ssd', False)
@@ -404,6 +405,9 @@
         self.vae_dtype = kwargs.get("vae_dtype", self.dtype)
         self.te_device = kwargs.get("te_device", None)
         self.te_dtype = kwargs.get("te_dtype", self.dtype)
+
+        # only for flux for now
+        self.quantize = kwargs.get("quantize", False)
         pass
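
The two new ModelConfig keys above are plain kwargs, so they can be exercised directly. A minimal sketch, assuming ModelConfig still accepts keyword-only construction and that name_or_path and dtype are existing keys; only is_flux and quantize are introduced by this diff:

# hedged usage sketch, not part of the diff; name_or_path / dtype are assumed
# existing ModelConfig keys, only is_flux and quantize are new here
from toolkit.config_modules import ModelConfig

model_config = ModelConfig(
    name_or_path="black-forest-labs/FLUX.1-dev",  # assumed checkpoint id
    dtype="bf16",
    is_flux=True,   # routes loading, LoRA targeting and prompt encoding to the Flux branches
    quantize=True,  # new key: qfloat8-quantize the transformer and T5 at load time (Flux only for now)
)
assert model_config.is_flux and model_config.quantize
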
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 273495c3..e9ee1ded 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -1361,6 +1361,8 @@ class LatentCachingMixin:
                 file_item.latent_space_version = 'sd3'
             elif self.sd.is_auraflow:
                 file_item.latent_space_version = 'sdxl'
+            elif self.sd.is_flux:
+                file_item.latent_space_version = 'flux'
             elif self.sd.model_config.is_pixart_sigma:
                 file_item.latent_space_version = 'sdxl'
             else:
diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py
index 9cb377fc..a3d369bc 100644
--- a/toolkit/lora_special.py
+++ b/toolkit/lora_special.py
@@ -159,6 +159,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             is_v3=False,
             is_pixart: bool = False,
             is_auraflow: bool = False,
+            is_flux: bool = False,
             use_bias: bool = False,
             is_lorm: bool = False,
             ignore_if_contains = None,
@@ -216,6 +217,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         self.is_v3 = is_v3
         self.is_pixart = is_pixart
         self.is_auraflow = is_auraflow
+        self.is_flux = is_flux
         self.network_type = network_type
         if self.network_type.lower() == "dora":
             self.module_class = DoRAModule
@@ -250,7 +252,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                 target_replace_modules: List[torch.nn.Module],
         ) -> List[LoRAModule]:
             unet_prefix = self.LORA_PREFIX_UNET
-            if is_pixart or is_v3 or is_auraflow:
+            if is_pixart or is_v3 or is_auraflow or is_flux:
                 unet_prefix = f"lora_transformer"
 
             prefix = (
@@ -293,6 +295,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                     if self.transformer_only and self.is_pixart and is_unet:
                         if "transformer_blocks" not in lora_name:
                             skip = True
+                    if self.transformer_only and self.is_flux and is_unet:
+                        if "transformer_blocks" not in lora_name:
+                            skip = True
 
                     if (is_linear or is_conv2d) and not skip:
@@ -393,6 +398,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         if is_auraflow:
             target_modules = ["AuraFlowTransformer2DModel"]
 
+        if is_flux:
+            target_modules = ["FluxTransformer2DModel"]
+
         if train_unet:
             self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
         else:
@@ -454,7 +462,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr)
 
         if self.full_train_in_out:
-            if self.is_pixart or self.is_auraflow:
+            if self.is_pixart or self.is_auraflow or self.is_flux:
                 all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())})
                 all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())})
             else:
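
To illustrate what the LoRA changes above do for Flux: modules found inside a FluxTransformer2DModel get the lora_transformer prefix, and with transformer_only enabled any module whose kohya-style name does not contain "transformer_blocks" is skipped. A small sketch with illustrative module paths (the paths are examples, not taken from this diff); note that single_transformer_blocks also survives the substring check:

# illustrative sketch, not part of the diff: example Linear paths one might find
# inside a FluxTransformer2DModel and how the new naming/skip rules treat them
candidate_paths = [
    "transformer_blocks.0.attn.to_q",        # kept
    "single_transformer_blocks.5.proj_mlp",  # kept, substring match on "transformer_blocks"
    "x_embedder",                            # skipped when transformer_only=True
    "proj_out",                              # skipped when transformer_only=True
]

prefix = "lora_transformer"  # unet_prefix chosen above when is_flux is set
transformer_only = True

for path in candidate_paths:
    lora_name = f"{prefix}.{path}".replace(".", "_")  # kohya-style LoRA module name
    skip = transformer_only and "transformer_blocks" not in lora_name
    print(f"{lora_name:55s} {'skip' if skip else 'add LoRA'}")
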
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index a728c111..1a419cc8 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -41,17 +41,21 @@ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAda
     StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
     StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
     StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \
-    StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel
+    StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel, FluxPipeline, \
+    FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler
 import diffusers
 from diffusers import \
     AutoencoderKL, \
     UNet2DConditionModel
 from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler, PixArtSigmaPipeline
-from transformers import T5EncoderModel, BitsAndBytesConfig, UMT5EncoderModel
+from transformers import T5EncoderModel, BitsAndBytesConfig, UMT5EncoderModel, T5TokenizerFast
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
 
 from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
 from toolkit.util.inverse_cfg import inverse_classifier_guidance
 
+from optimum.quanto import freeze, qfloat8, quantize
+
 # tell it to shut up
 diffusers.logging.set_verbosity(diffusers.logging.ERROR)
@@ -78,6 +82,7 @@ DO_NOT_TRAIN_WEIGHTS = [
 
 DeviceStatePreset = Literal['cache_latents', 'generate']
 
+
 class BlankNetwork:
 
     def __init__(self):
@@ -101,10 +106,6 @@ def flush():
 
 UNET_IN_CHANNELS = 4  # Stable Diffusion の in_channels は 4 で固定。XLも同じ。
 # VAE_SCALE_FACTOR = 8  # 2 ** (len(vae.config.block_out_channels) - 1) = 8
-# if is type checking
-if typing.TYPE_CHECKING:
-    from diffusers.schedulers import KarrasDiffusionSchedulers
-    from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
 
 
 class StableDiffusion:
@@ -158,6 +159,7 @@
         self.is_vega = model_config.is_vega
         self.is_pixart = model_config.is_pixart
         self.is_auraflow = model_config.is_auraflow
+        self.is_flux = model_config.is_flux
         self.use_text_encoder_1 = model_config.use_text_encoder_1
         self.use_text_encoder_2 = model_config.use_text_encoder_2
@@ -443,6 +445,71 @@
             text_encoder.eval()
             pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
             tokenizer = pipe.tokenizer
+
+        elif self.model_config.is_flux:
+            print("Loading Flux model")
+            scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler")
+            print("Loading vae")
+            vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae", torch_dtype=dtype)
+            flush()
+            print("Loading transformer")
+
+            transformer = FluxTransformer2DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=dtype)
+            transformer.to(self.device_torch, dtype=dtype)
+            flush()
+
+            if self.model_config.quantize:
+                print("Quantizing transformer")
+                quantize(transformer, weights=qfloat8)
+                freeze(transformer)
+                flush()
+
+            print("Loading t5")
+            text_encoder_2 = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder_2", torch_dtype=dtype)
+            tokenizer_2 = T5TokenizerFast.from_pretrained(model_path, subfolder="tokenizer_2", torch_dtype=dtype)
+            text_encoder_2.to(self.device_torch, dtype=dtype)
+            flush()
+
+            if self.model_config.quantize:
+                print("Quantizing T5")
+                quantize(text_encoder_2, weights=qfloat8)
+                freeze(text_encoder_2)
+                flush()
+
+            print("Loading clip")
+            text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=dtype)
+            tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer", torch_dtype=dtype)
+            text_encoder.to(self.device_torch, dtype=dtype)
+
+            print("making pipe")
+            pipe = FluxPipeline(
+                scheduler=scheduler,
+                text_encoder=text_encoder,
+                tokenizer=tokenizer,
+                text_encoder_2=None,
+                tokenizer_2=tokenizer_2,
+                vae=vae,
+                transformer=None,
+            )
+            pipe.text_encoder_2 = text_encoder_2
+            pipe.transformer = transformer
+
+            print("preparing")
+
+            text_encoder = [pipe.text_encoder, pipe.text_encoder_2]
+            tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
+
+            pipe.transformer = pipe.transformer.to(self.device_torch)
+
+            flush()
+            text_encoder[0].to(self.device_torch)
+            text_encoder[0].requires_grad_(False)
+            text_encoder[0].eval()
+            text_encoder[1].to(self.device_torch)
+            text_encoder[1].requires_grad_(False)
+            text_encoder[1].eval()
+            pipe.transformer = pipe.transformer.to(self.device_torch)
+            flush()
         else:
             if self.custom_pipeline is not None:
                 pipln = self.custom_pipeline
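
The quantization calls above follow the standard optimum.quanto pattern. A standalone sketch on a toy module (not the actual Flux transformer or T5) showing the same quantize-then-freeze sequence; after freeze() the float weights are replaced by qfloat8 tensors and the module can still be called normally:

# minimal optimum.quanto sketch, same API calls as the diff, toy model instead
# of FluxTransformer2DModel / T5EncoderModel
import torch
from torch import nn
from optimum.quanto import freeze, qfloat8, quantize

toy = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))

quantize(toy, weights=qfloat8)  # swap Linear weights for qfloat8-quantized versions
freeze(toy)                     # materialize the quantized weights, dropping the float originals

with torch.no_grad():
    out = toy(torch.randn(2, 64))
print(out.shape)  # torch.Size([2, 64])
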
@@ -515,7 +582,7 @@
 
         # add hacks to unet to help training
         # pipe.unet = prepare_unet_for_training(pipe.unet)
-        if self.is_pixart or self.is_v3 or self.is_auraflow:
+        if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux:
             # pixart and sd3 dont use a unet
             self.unet = pipe.transformer
         else:
@@ -695,6 +762,18 @@
                 **extra_args
             ).to(self.device_torch)
             pipeline.watermark = None
+        elif self.is_flux:
+            pipeline = FluxPipeline(
+                vae=self.vae,
+                transformer=self.unet,
+                text_encoder=self.text_encoder[0],
+                text_encoder_2=self.text_encoder[1],
+                tokenizer=self.tokenizer[0],
+                tokenizer_2=self.tokenizer[1],
+                scheduler=noise_scheduler,
+                **extra_args
+            ).to(self.device_torch)
+            pipeline.watermark = None
         elif self.is_v3:
             pipeline = Pipe(
                 vae=self.vae,
@@ -954,6 +1033,19 @@
                 latents=gen_config.latents,
                 **extra
             ).images[0]
+        elif self.is_flux:
+            img = pipeline(
+                prompt_embeds=conditional_embeds.text_embeds,
+                pooled_prompt_embeds=conditional_embeds.pooled_embeds,
+                # negative_prompt_embeds=unconditional_embeds.text_embeds,
+                # negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
+                height=gen_config.height,
+                width=gen_config.width,
+                num_inference_steps=gen_config.num_inference_steps,
+                guidance_scale=gen_config.guidance_scale,
+                latents=gen_config.latents,
+                **extra
+            ).images[0]
         elif self.is_pixart:
             # needs attention masks for some reason
             img = pipeline(
@@ -1073,10 +1165,14 @@
         if width is None:
             width = pixel_width // VAE_SCALE_FACTOR
 
+        num_channels = self.unet.config['in_channels']
+        if self.is_flux:
+            # has 64 channels in for some reason
+            num_channels = 16
         noise = torch.randn(
             (
                 batch_size,
-                self.unet.config['in_channels'],
+                num_channels,
                 height,
                 width,
             ),
@@ -1429,7 +1525,88 @@
             self.unet.to(self.device_torch)
         if self.unet.dtype != self.torch_dtype:
             self.unet = self.unet.to(dtype=self.torch_dtype)
-        if self.is_v3:
+        if self.is_flux:
+            with torch.no_grad():
+                VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) # 16 . Maybe dont subtract
+                # this is what diffusers does
+                text_ids = torch.zeros(latent_model_input.shape[0], text_embeddings.text_embeds.shape[1], 3).to(
+                    device=self.device_torch, dtype=self.text_encoder[0].dtype
+                )
+                # todo check these
+                # height = latent_model_input.shape[2] * VAE_SCALE_FACTOR
+                # width = latent_model_input.shape[3] * VAE_SCALE_FACTOR
+                height = latent_model_input.shape[2] * VAE_SCALE_FACTOR # 128
+                width = latent_model_input.shape[3] * VAE_SCALE_FACTOR # 128
+
+                width_latent = latent_model_input.shape[3]
+                height_latent = latent_model_input.shape[2]
+
+                latent_image_ids = self.pipeline._prepare_latent_image_ids(
+                    batch_size=latent_model_input.shape[0],
+                    height=height_latent,
+                    width=width_latent,
+                    device=self.device_torch,
+                    dtype=self.torch_dtype,
+                )
+
+                # # handle guidance
+                guidance_scale = 1.0 # ?
+                if self.unet.config.guidance_embeds:
+                    guidance = torch.tensor([guidance_scale], device=self.device_torch)
+                    guidance = guidance.expand(latents.shape[0])
+                else:
+                    guidance = None
+
+                # not sure how to handle this
+
+                # sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+                # image_seq_len = latents.shape[1]
+                # mu = calculate_shift(
+                #     image_seq_len,
+                #     self.scheduler.config.base_image_seq_len,
+                #     self.scheduler.config.max_image_seq_len,
+                #     self.scheduler.config.base_shift,
+                #     self.scheduler.config.max_shift,
+                # )
+                # timesteps, num_inference_steps = retrieve_timesteps(
+                #     self.scheduler,
+                #     num_inference_steps,
+                #     device,
+                #     timesteps,
+                #     sigmas,
+                #     mu=mu,
+                # )
+                latent_model_input = self.pipeline._pack_latents(
+                    latent_model_input,
+                    batch_size=latent_model_input.shape[0],
+                    num_channels_latents=latent_model_input.shape[1], # 16
+                    height=height_latent, # 128
+                    width=width_latent, # 128
+                )
+
+            noise_pred = self.unet(
+                hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype), # [1, 4096, 64]
+                # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
+                # todo make sure this doesnt change
+                timestep=timestep / 1000, # timestep is 1000 scale
+                encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), # [1, 512, 4096]
+                pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype), # [1, 768]
+                txt_ids=text_ids, # [1, 512, 3]
+                img_ids=latent_image_ids, # [1, 4096, 3]
+                guidance=guidance,
+                return_dict=False,
+                **kwargs,
+            )[0]
+
+            # unpack latents
+            noise_pred = self.pipeline._unpack_latents(
+                noise_pred,
+                height=height, # 1024
+                width=height, # 1024
+                vae_scale_factor=VAE_SCALE_FACTOR * 2, # should be 16 not sure why
+            )
+        elif self.is_v3:
             noise_pred = self.unet(
                 hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
                 timestep=timestep,
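
The 64-channel in_channels and the _pack_latents/_unpack_latents calls above come from Flux operating on 2x2 patches of the 16-channel VAE latent: a 1024px image gives a 128x128x16 latent, which packs into a [B, 4096, 64] token sequence, and the extra factor of two passed as vae_scale_factor accounts for that patching on top of the VAE's spatial downsampling. Below is a shape-only re-derivation; it is my own sketch that reproduces the same shapes as the diffusers helpers, not the diffusers source:

# shape sketch, not part of the diff: why in_channels is 64 and the packed
# sequence length is 4096 for a 128x128, 16-channel Flux latent
import torch

def pack_latents(latents):
    b, c, h, w = latents.shape                          # e.g. [1, 16, 128, 128]
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)  # split into 2x2 patches
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)  # [1, 4096, 64]

def unpack_latents(packed, h, w):
    b, seq, c4 = packed.shape
    latents = packed.view(b, h // 2, w // 2, c4 // 4, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    return latents.reshape(b, c4 // 4, h, w)            # back to [1, 16, 128, 128]

x = torch.randn(1, 16, 128, 128)   # 1024px image -> 128x128 latent with 16 channels
packed = pack_latents(x)
print(packed.shape)                # torch.Size([1, 4096, 64])
assert torch.equal(unpack_latents(packed, 128, 128), x)
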
@@ -1656,6 +1833,21 @@
                 embeds,
                 attention_mask=attention_mask, # not used
             )
+        elif self.is_flux:
+            prompt_embeds, pooled_prompt_embeds = train_tools.encode_prompts_flux(
+                self.tokenizer, # list
+                self.text_encoder, # list
+                prompt,
+                truncate=not long_prompts,
+                max_length=512,
+                dropout_prob=dropout_prob
+            )
+            pe = PromptEmbeds(
+                prompt_embeds
+            )
+            pe.pooled_embeds = pooled_prompt_embeds
+            return pe
+
         elif isinstance(self.text_encoder, T5EncoderModel):
             embeds, attention_mask = train_tools.encode_prompts_pixart(
@@ -1989,7 +2181,7 @@
         named_params = self.named_parameters(vae=False, unet=unet, text_encoder=False, state_dict_keys=True)
         unet_lr = unet_lr if unet_lr is not None else default_lr
         params = []
-        if self.is_pixart or self.is_auraflow:
+        if self.is_pixart or self.is_auraflow or self.is_flux:
             for param in named_params.values():
                 if param.requires_grad:
                     params.append(param)
@@ -2035,7 +2227,7 @@
     def save_device_state(self):
         # saves the current device state for all modules
         # this is useful for when we want to alter the state and restore it
-        if self.is_pixart or self.is_v3 or self.is_auraflow:
+        if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux:
             unet_has_grad = self.unet.proj_out.weight.requires_grad
         else:
             unet_has_grad = self.unet.conv_in.weight.requires_grad
diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py
index 6a71e558..83b88444 100644
--- a/toolkit/train_tools.py
+++ b/toolkit/train_tools.py
@@ -3,7 +3,7 @@ import hashlib
 import json
 import os
 import time
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, Union, List
 import sys
 
 from torch.cuda.amp import GradScaler
@@ -766,6 +766,73 @@
     return prompt_embeds, prompt_attention_mask
 
 
+def encode_prompts_flux(
+        tokenizer: List[Union['CLIPTokenizer', 'T5Tokenizer']],
+        text_encoder: List[Union['CLIPTextModel', 'T5EncoderModel']],
+        prompts: list[str],
+        truncate: bool = True,
+        max_length=None,
+        dropout_prob=0.0,
+):
+    if max_length is None:
+        max_length = 512
+
+    if dropout_prob > 0.0:
+        # randomly drop out prompts
+        prompts = [
+            prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts
+        ]
+
+    device = text_encoder[0].device
+    dtype = text_encoder[0].dtype
+
+    batch_size = len(prompts)
+
+    # clip
+    text_inputs = tokenizer[0](
+        prompts,
+        padding="max_length",
+        max_length=tokenizer[0].model_max_length,
+        truncation=True,
+        return_overflowing_tokens=False,
+        return_length=False,
+        return_tensors="pt",
+    )
+
+    text_input_ids = text_inputs.input_ids
+
+    prompt_embeds = text_encoder[0](text_input_ids.to(device), output_hidden_states=False)
+
+    # Use pooled output of CLIPTextModel
+    pooled_prompt_embeds = prompt_embeds.pooler_output
+    pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=dtype, device=device)
+
+    # T5
+    text_inputs = tokenizer[1](
+        prompts,
+        padding="max_length",
+        max_length=max_length,
+        truncation=True,
+        return_length=False,
+        return_overflowing_tokens=False,
+        return_tensors="pt",
+    )
+    text_input_ids = text_inputs.input_ids
+
+    prompt_embeds = text_encoder[1](text_input_ids.to(device), output_hidden_states=False)[0]
+
+    dtype = text_encoder[1].dtype
+    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+    # prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
+    # prompt_embeds = prompt_embeds * prompt_attention_mask
+    # _, seq_len, _ = prompt_embeds.shape
+
+    # they dont do prompt attention mask?
+    # prompt_attention_mask = torch.ones((batch_size, seq_len), dtype=dtype, device=device)
+
+    return prompt_embeds, pooled_prompt_embeds
+
+
 # for XL
 def get_add_time_ids(
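
For reference, a hypothetical call of the new encode_prompts_flux helper, loading the tokenizers and encoders from the same subfolders the StableDiffusion loader uses above. The checkpoint id is an assumption (any repo with the Flux layout works), and the expected shapes match the comments in the noise-prediction hunk ([1, 512, 4096] T5 embeddings, [1, 768] CLIP pooled output):

# hedged usage sketch, not part of the diff
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from toolkit.train_tools import encode_prompts_flux

base = "black-forest-labs/FLUX.1-dev"  # assumed checkpoint id
tokenizer = [
    CLIPTokenizer.from_pretrained(base, subfolder="tokenizer"),
    T5TokenizerFast.from_pretrained(base, subfolder="tokenizer_2"),
]
text_encoder = [
    CLIPTextModel.from_pretrained(base, subfolder="text_encoder", torch_dtype=torch.bfloat16),
    T5EncoderModel.from_pretrained(base, subfolder="text_encoder_2", torch_dtype=torch.bfloat16),
]

prompt_embeds, pooled_prompt_embeds = encode_prompts_flux(
    tokenizer,             # index 0 is always CLIP, index 1 is always T5
    text_encoder,
    ["a photo of a cat"],
    max_length=512,        # same T5 sequence length the model code passes in
    dropout_prob=0.0,
)
print(prompt_embeds.shape, pooled_prompt_embeds.shape)  # expected: [1, 512, 4096] and [1, 768]
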