mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Added flux training. Still a WIP. Wont train right without rectified flow working right
This commit is contained in:
@@ -1234,7 +1234,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
||||||
|
|
||||||
if self.train_config.gradient_checkpointing:
|
if self.train_config.gradient_checkpointing:
|
||||||
unet.enable_gradient_checkpointing()
|
if self.sd.is_flux:
|
||||||
|
unet.gradient_checkpointing = True
|
||||||
|
else:
|
||||||
|
unet.enable_gradient_checkpointing()
|
||||||
if isinstance(text_encoder, list):
|
if isinstance(text_encoder, list):
|
||||||
for te in text_encoder:
|
for te in text_encoder:
|
||||||
if hasattr(te, 'enable_gradient_checkpointing'):
|
if hasattr(te, 'enable_gradient_checkpointing'):
|
||||||
@@ -1325,6 +1328,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
is_v3=self.model_config.is_v3,
|
is_v3=self.model_config.is_v3,
|
||||||
is_pixart=self.model_config.is_pixart,
|
is_pixart=self.model_config.is_pixart,
|
||||||
is_auraflow=self.model_config.is_auraflow,
|
is_auraflow=self.model_config.is_auraflow,
|
||||||
|
is_flux=self.model_config.is_flux,
|
||||||
is_ssd=self.model_config.is_ssd,
|
is_ssd=self.model_config.is_ssd,
|
||||||
is_vega=self.model_config.is_vega,
|
is_vega=self.model_config.is_vega,
|
||||||
dropout=self.network_config.dropout,
|
dropout=self.network_config.dropout,
|
||||||
|
|||||||
@@ -367,6 +367,7 @@ class ModelConfig:
|
|||||||
self.is_pixart_sigma: bool = kwargs.get('is_pixart_sigma', False)
|
self.is_pixart_sigma: bool = kwargs.get('is_pixart_sigma', False)
|
||||||
self.is_auraflow: bool = kwargs.get('is_auraflow', False)
|
self.is_auraflow: bool = kwargs.get('is_auraflow', False)
|
||||||
self.is_v3: bool = kwargs.get('is_v3', False)
|
self.is_v3: bool = kwargs.get('is_v3', False)
|
||||||
|
self.is_flux: bool = kwargs.get('is_flux', False)
|
||||||
if self.is_pixart_sigma:
|
if self.is_pixart_sigma:
|
||||||
self.is_pixart = True
|
self.is_pixart = True
|
||||||
self.is_ssd: bool = kwargs.get('is_ssd', False)
|
self.is_ssd: bool = kwargs.get('is_ssd', False)
|
||||||
@@ -404,6 +405,9 @@ class ModelConfig:
|
|||||||
self.vae_dtype = kwargs.get("vae_dtype", self.dtype)
|
self.vae_dtype = kwargs.get("vae_dtype", self.dtype)
|
||||||
self.te_device = kwargs.get("te_device", None)
|
self.te_device = kwargs.get("te_device", None)
|
||||||
self.te_dtype = kwargs.get("te_dtype", self.dtype)
|
self.te_dtype = kwargs.get("te_dtype", self.dtype)
|
||||||
|
|
||||||
|
# only for flux for now
|
||||||
|
self.quantize = kwargs.get("quantize", False)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1361,6 +1361,8 @@ class LatentCachingMixin:
|
|||||||
file_item.latent_space_version = 'sd3'
|
file_item.latent_space_version = 'sd3'
|
||||||
elif self.sd.is_auraflow:
|
elif self.sd.is_auraflow:
|
||||||
file_item.latent_space_version = 'sdxl'
|
file_item.latent_space_version = 'sdxl'
|
||||||
|
elif self.sd.is_flux:
|
||||||
|
file_item.latent_space_version = 'flux'
|
||||||
elif self.sd.model_config.is_pixart_sigma:
|
elif self.sd.model_config.is_pixart_sigma:
|
||||||
file_item.latent_space_version = 'sdxl'
|
file_item.latent_space_version = 'sdxl'
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -159,6 +159,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
|||||||
is_v3=False,
|
is_v3=False,
|
||||||
is_pixart: bool = False,
|
is_pixart: bool = False,
|
||||||
is_auraflow: bool = False,
|
is_auraflow: bool = False,
|
||||||
|
is_flux: bool = False,
|
||||||
use_bias: bool = False,
|
use_bias: bool = False,
|
||||||
is_lorm: bool = False,
|
is_lorm: bool = False,
|
||||||
ignore_if_contains = None,
|
ignore_if_contains = None,
|
||||||
@@ -216,6 +217,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
|||||||
self.is_v3 = is_v3
|
self.is_v3 = is_v3
|
||||||
self.is_pixart = is_pixart
|
self.is_pixart = is_pixart
|
||||||
self.is_auraflow = is_auraflow
|
self.is_auraflow = is_auraflow
|
||||||
|
self.is_flux = is_flux
|
||||||
self.network_type = network_type
|
self.network_type = network_type
|
||||||
if self.network_type.lower() == "dora":
|
if self.network_type.lower() == "dora":
|
||||||
self.module_class = DoRAModule
|
self.module_class = DoRAModule
|
||||||
@@ -250,7 +252,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
|||||||
target_replace_modules: List[torch.nn.Module],
|
target_replace_modules: List[torch.nn.Module],
|
||||||
) -> List[LoRAModule]:
|
) -> List[LoRAModule]:
|
||||||
unet_prefix = self.LORA_PREFIX_UNET
|
unet_prefix = self.LORA_PREFIX_UNET
|
||||||
if is_pixart or is_v3 or is_auraflow:
|
if is_pixart or is_v3 or is_auraflow or is_flux:
|
||||||
unet_prefix = f"lora_transformer"
|
unet_prefix = f"lora_transformer"
|
||||||
|
|
||||||
prefix = (
|
prefix = (
|
||||||
@@ -293,6 +295,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
|||||||
if self.transformer_only and self.is_pixart and is_unet:
|
if self.transformer_only and self.is_pixart and is_unet:
|
||||||
if "transformer_blocks" not in lora_name:
|
if "transformer_blocks" not in lora_name:
|
||||||
skip = True
|
skip = True
|
||||||
|
if self.transformer_only and self.is_flux and is_unet:
|
||||||
|
if "transformer_blocks" not in lora_name:
|
||||||
|
skip = True
|
||||||
|
|
||||||
if (is_linear or is_conv2d) and not skip:
|
if (is_linear or is_conv2d) and not skip:
|
||||||
|
|
||||||
@@ -393,6 +398,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
|||||||
if is_auraflow:
|
if is_auraflow:
|
||||||
target_modules = ["AuraFlowTransformer2DModel"]
|
target_modules = ["AuraFlowTransformer2DModel"]
|
||||||
|
|
||||||
|
if is_flux:
|
||||||
|
target_modules = ["FluxTransformer2DModel"]
|
||||||
|
|
||||||
if train_unet:
|
if train_unet:
|
||||||
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
||||||
else:
|
else:
|
||||||
@@ -454,7 +462,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
|||||||
all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr)
|
all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr)
|
||||||
|
|
||||||
if self.full_train_in_out:
|
if self.full_train_in_out:
|
||||||
if self.is_pixart or self.is_auraflow:
|
if self.is_pixart or self.is_auraflow or self.is_flux:
|
||||||
all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())})
|
all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())})
|
||||||
all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())})
|
all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())})
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -41,17 +41,21 @@ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAda
|
|||||||
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
|
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
|
||||||
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
|
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
|
||||||
StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \
|
StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \
|
||||||
StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel
|
StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel, FluxPipeline, \
|
||||||
|
FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler
|
||||||
import diffusers
|
import diffusers
|
||||||
from diffusers import \
|
from diffusers import \
|
||||||
AutoencoderKL, \
|
AutoencoderKL, \
|
||||||
UNet2DConditionModel
|
UNet2DConditionModel
|
||||||
from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler, PixArtSigmaPipeline
|
from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler, PixArtSigmaPipeline
|
||||||
from transformers import T5EncoderModel, BitsAndBytesConfig, UMT5EncoderModel
|
from transformers import T5EncoderModel, BitsAndBytesConfig, UMT5EncoderModel, T5TokenizerFast
|
||||||
|
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
|
||||||
|
|
||||||
from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
|
from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
|
||||||
from toolkit.util.inverse_cfg import inverse_classifier_guidance
|
from toolkit.util.inverse_cfg import inverse_classifier_guidance
|
||||||
|
|
||||||
|
from optimum.quanto import freeze, qfloat8, quantize
|
||||||
|
|
||||||
# tell it to shut up
|
# tell it to shut up
|
||||||
diffusers.logging.set_verbosity(diffusers.logging.ERROR)
|
diffusers.logging.set_verbosity(diffusers.logging.ERROR)
|
||||||
|
|
||||||
@@ -78,6 +82,7 @@ DO_NOT_TRAIN_WEIGHTS = [
|
|||||||
DeviceStatePreset = Literal['cache_latents', 'generate']
|
DeviceStatePreset = Literal['cache_latents', 'generate']
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class BlankNetwork:
|
class BlankNetwork:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -101,10 +106,6 @@ def flush():
|
|||||||
UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも同じ。
|
UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも同じ。
|
||||||
# VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
|
# VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
|
||||||
|
|
||||||
# if is type checking
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
|
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusion:
|
class StableDiffusion:
|
||||||
@@ -158,6 +159,7 @@ class StableDiffusion:
|
|||||||
self.is_vega = model_config.is_vega
|
self.is_vega = model_config.is_vega
|
||||||
self.is_pixart = model_config.is_pixart
|
self.is_pixart = model_config.is_pixart
|
||||||
self.is_auraflow = model_config.is_auraflow
|
self.is_auraflow = model_config.is_auraflow
|
||||||
|
self.is_flux = model_config.is_flux
|
||||||
|
|
||||||
self.use_text_encoder_1 = model_config.use_text_encoder_1
|
self.use_text_encoder_1 = model_config.use_text_encoder_1
|
||||||
self.use_text_encoder_2 = model_config.use_text_encoder_2
|
self.use_text_encoder_2 = model_config.use_text_encoder_2
|
||||||
@@ -443,6 +445,71 @@ class StableDiffusion:
|
|||||||
text_encoder.eval()
|
text_encoder.eval()
|
||||||
pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
|
pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
|
||||||
tokenizer = pipe.tokenizer
|
tokenizer = pipe.tokenizer
|
||||||
|
|
||||||
|
elif self.model_config.is_flux:
|
||||||
|
print("Loading Flux model")
|
||||||
|
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler")
|
||||||
|
print("Loading vae")
|
||||||
|
vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae", torch_dtype=dtype)
|
||||||
|
flush()
|
||||||
|
print("Loading transformer")
|
||||||
|
|
||||||
|
transformer = FluxTransformer2DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=dtype)
|
||||||
|
transformer.to(self.device_torch, dtype=dtype)
|
||||||
|
flush()
|
||||||
|
|
||||||
|
if self.model_config.quantize:
|
||||||
|
print("Quantizing transformer")
|
||||||
|
quantize(transformer, weights=qfloat8)
|
||||||
|
freeze(transformer)
|
||||||
|
flush()
|
||||||
|
|
||||||
|
print("Loading t5")
|
||||||
|
text_encoder_2 = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder_2", torch_dtype=dtype)
|
||||||
|
tokenizer_2 = T5TokenizerFast.from_pretrained(model_path, subfolder="tokenizer_2", torch_dtype=dtype)
|
||||||
|
text_encoder_2.to(self.device_torch, dtype=dtype)
|
||||||
|
flush()
|
||||||
|
|
||||||
|
if self.model_config.quantize:
|
||||||
|
print("Quantizing T5")
|
||||||
|
quantize(text_encoder_2, weights=qfloat8)
|
||||||
|
freeze(text_encoder_2)
|
||||||
|
flush()
|
||||||
|
|
||||||
|
print("Loading clip")
|
||||||
|
text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=dtype)
|
||||||
|
tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer", torch_dtype=dtype)
|
||||||
|
text_encoder.to(self.device_torch, dtype=dtype)
|
||||||
|
|
||||||
|
print("making pipe")
|
||||||
|
pipe = FluxPipeline(
|
||||||
|
scheduler=scheduler,
|
||||||
|
text_encoder=text_encoder,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
text_encoder_2=None,
|
||||||
|
tokenizer_2=tokenizer_2,
|
||||||
|
vae=vae,
|
||||||
|
transformer=None,
|
||||||
|
)
|
||||||
|
pipe.text_encoder_2 = text_encoder_2
|
||||||
|
pipe.transformer = transformer
|
||||||
|
|
||||||
|
print("preparing")
|
||||||
|
|
||||||
|
text_encoder = [pipe.text_encoder, pipe.text_encoder_2]
|
||||||
|
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
|
||||||
|
|
||||||
|
pipe.transformer = pipe.transformer.to(self.device_torch)
|
||||||
|
|
||||||
|
flush()
|
||||||
|
text_encoder[0].to(self.device_torch)
|
||||||
|
text_encoder[0].requires_grad_(False)
|
||||||
|
text_encoder[0].eval()
|
||||||
|
text_encoder[1].to(self.device_torch)
|
||||||
|
text_encoder[1].requires_grad_(False)
|
||||||
|
text_encoder[1].eval()
|
||||||
|
pipe.transformer = pipe.transformer.to(self.device_torch)
|
||||||
|
flush()
|
||||||
else:
|
else:
|
||||||
if self.custom_pipeline is not None:
|
if self.custom_pipeline is not None:
|
||||||
pipln = self.custom_pipeline
|
pipln = self.custom_pipeline
|
||||||
@@ -515,7 +582,7 @@ class StableDiffusion:
|
|||||||
# add hacks to unet to help training
|
# add hacks to unet to help training
|
||||||
# pipe.unet = prepare_unet_for_training(pipe.unet)
|
# pipe.unet = prepare_unet_for_training(pipe.unet)
|
||||||
|
|
||||||
if self.is_pixart or self.is_v3 or self.is_auraflow:
|
if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux:
|
||||||
# pixart and sd3 dont use a unet
|
# pixart and sd3 dont use a unet
|
||||||
self.unet = pipe.transformer
|
self.unet = pipe.transformer
|
||||||
else:
|
else:
|
||||||
@@ -695,6 +762,18 @@ class StableDiffusion:
|
|||||||
**extra_args
|
**extra_args
|
||||||
).to(self.device_torch)
|
).to(self.device_torch)
|
||||||
pipeline.watermark = None
|
pipeline.watermark = None
|
||||||
|
elif self.is_flux:
|
||||||
|
pipeline = FluxPipeline(
|
||||||
|
vae=self.vae,
|
||||||
|
transformer=self.unet,
|
||||||
|
text_encoder=self.text_encoder[0],
|
||||||
|
text_encoder_2=self.text_encoder[1],
|
||||||
|
tokenizer=self.tokenizer[0],
|
||||||
|
tokenizer_2=self.tokenizer[1],
|
||||||
|
scheduler=noise_scheduler,
|
||||||
|
**extra_args
|
||||||
|
).to(self.device_torch)
|
||||||
|
pipeline.watermark = None
|
||||||
elif self.is_v3:
|
elif self.is_v3:
|
||||||
pipeline = Pipe(
|
pipeline = Pipe(
|
||||||
vae=self.vae,
|
vae=self.vae,
|
||||||
@@ -954,6 +1033,19 @@ class StableDiffusion:
|
|||||||
latents=gen_config.latents,
|
latents=gen_config.latents,
|
||||||
**extra
|
**extra
|
||||||
).images[0]
|
).images[0]
|
||||||
|
elif self.is_flux:
|
||||||
|
img = pipeline(
|
||||||
|
prompt_embeds=conditional_embeds.text_embeds,
|
||||||
|
pooled_prompt_embeds=conditional_embeds.pooled_embeds,
|
||||||
|
# negative_prompt_embeds=unconditional_embeds.text_embeds,
|
||||||
|
# negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
|
||||||
|
height=gen_config.height,
|
||||||
|
width=gen_config.width,
|
||||||
|
num_inference_steps=gen_config.num_inference_steps,
|
||||||
|
guidance_scale=gen_config.guidance_scale,
|
||||||
|
latents=gen_config.latents,
|
||||||
|
**extra
|
||||||
|
).images[0]
|
||||||
elif self.is_pixart:
|
elif self.is_pixart:
|
||||||
# needs attention masks for some reason
|
# needs attention masks for some reason
|
||||||
img = pipeline(
|
img = pipeline(
|
||||||
@@ -1073,10 +1165,14 @@ class StableDiffusion:
|
|||||||
if width is None:
|
if width is None:
|
||||||
width = pixel_width // VAE_SCALE_FACTOR
|
width = pixel_width // VAE_SCALE_FACTOR
|
||||||
|
|
||||||
|
num_channels = self.unet.config['in_channels']
|
||||||
|
if self.is_flux:
|
||||||
|
# has 64 channels in for some reason
|
||||||
|
num_channels = 16
|
||||||
noise = torch.randn(
|
noise = torch.randn(
|
||||||
(
|
(
|
||||||
batch_size,
|
batch_size,
|
||||||
self.unet.config['in_channels'],
|
num_channels,
|
||||||
height,
|
height,
|
||||||
width,
|
width,
|
||||||
),
|
),
|
||||||
@@ -1429,7 +1525,88 @@ class StableDiffusion:
|
|||||||
self.unet.to(self.device_torch)
|
self.unet.to(self.device_torch)
|
||||||
if self.unet.dtype != self.torch_dtype:
|
if self.unet.dtype != self.torch_dtype:
|
||||||
self.unet = self.unet.to(dtype=self.torch_dtype)
|
self.unet = self.unet.to(dtype=self.torch_dtype)
|
||||||
if self.is_v3:
|
if self.is_flux:
|
||||||
|
with torch.no_grad():
|
||||||
|
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) # 16 . Maybe dont subtract
|
||||||
|
# this is what diffusers does
|
||||||
|
text_ids = torch.zeros(latent_model_input.shape[0], text_embeddings.text_embeds.shape[1], 3).to(
|
||||||
|
device=self.device_torch, dtype=self.text_encoder[0].dtype
|
||||||
|
)
|
||||||
|
# todo check these
|
||||||
|
# height = latent_model_input.shape[2] * VAE_SCALE_FACTOR
|
||||||
|
# width = latent_model_input.shape[3] * VAE_SCALE_FACTOR
|
||||||
|
height = latent_model_input.shape[2] * VAE_SCALE_FACTOR # 128
|
||||||
|
width = latent_model_input.shape[3] * VAE_SCALE_FACTOR # 128
|
||||||
|
|
||||||
|
width_latent = latent_model_input.shape[3]
|
||||||
|
height_latent = latent_model_input.shape[2]
|
||||||
|
|
||||||
|
latent_image_ids = self.pipeline._prepare_latent_image_ids(
|
||||||
|
batch_size=latent_model_input.shape[0],
|
||||||
|
height=height_latent,
|
||||||
|
width=width_latent,
|
||||||
|
device=self.device_torch,
|
||||||
|
dtype=self.torch_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
# # handle guidance
|
||||||
|
guidance_scale = 1.0 # ?
|
||||||
|
if self.unet.config.guidance_embeds:
|
||||||
|
guidance = torch.tensor([guidance_scale], device=self.device_torch)
|
||||||
|
guidance = guidance.expand(latents.shape[0])
|
||||||
|
else:
|
||||||
|
guidance = None
|
||||||
|
|
||||||
|
# not sure how to handle this
|
||||||
|
|
||||||
|
# sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||||
|
# image_seq_len = latents.shape[1]
|
||||||
|
# mu = calculate_shift(
|
||||||
|
# image_seq_len,
|
||||||
|
# self.scheduler.config.base_image_seq_len,
|
||||||
|
# self.scheduler.config.max_image_seq_len,
|
||||||
|
# self.scheduler.config.base_shift,
|
||||||
|
# self.scheduler.config.max_shift,
|
||||||
|
# )
|
||||||
|
# timesteps, num_inference_steps = retrieve_timesteps(
|
||||||
|
# self.scheduler,
|
||||||
|
# num_inference_steps,
|
||||||
|
# device,
|
||||||
|
# timesteps,
|
||||||
|
# sigmas,
|
||||||
|
# mu=mu,
|
||||||
|
# )
|
||||||
|
latent_model_input = self.pipeline._pack_latents(
|
||||||
|
latent_model_input,
|
||||||
|
batch_size=latent_model_input.shape[0],
|
||||||
|
num_channels_latents=latent_model_input.shape[1], # 16
|
||||||
|
height=height_latent, # 128
|
||||||
|
width=width_latent, # 128
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
noise_pred = self.unet(
|
||||||
|
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype), # [1, 4096, 64]
|
||||||
|
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||||
|
# todo make sure this doesnt change
|
||||||
|
timestep=timestep / 1000, # timestep is 1000 scale
|
||||||
|
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), # [1, 512, 4096]
|
||||||
|
pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype), # [1, 768]
|
||||||
|
txt_ids=text_ids, # [1, 512, 3]
|
||||||
|
img_ids=latent_image_ids, # [1, 4096, 3]
|
||||||
|
guidance=guidance,
|
||||||
|
return_dict=False,
|
||||||
|
**kwargs,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
# unpack latents
|
||||||
|
noise_pred = self.pipeline._unpack_latents(
|
||||||
|
noise_pred,
|
||||||
|
height=height, # 1024
|
||||||
|
width=height, # 1024
|
||||||
|
vae_scale_factor=VAE_SCALE_FACTOR * 2, # should be 16 not sure why
|
||||||
|
)
|
||||||
|
elif self.is_v3:
|
||||||
noise_pred = self.unet(
|
noise_pred = self.unet(
|
||||||
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
|
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
|
||||||
timestep=timestep,
|
timestep=timestep,
|
||||||
@@ -1656,6 +1833,21 @@ class StableDiffusion:
|
|||||||
embeds,
|
embeds,
|
||||||
attention_mask=attention_mask, # not used
|
attention_mask=attention_mask, # not used
|
||||||
)
|
)
|
||||||
|
elif self.is_flux:
|
||||||
|
prompt_embeds, pooled_prompt_embeds = train_tools.encode_prompts_flux(
|
||||||
|
self.tokenizer, # list
|
||||||
|
self.text_encoder, # list
|
||||||
|
prompt,
|
||||||
|
truncate=not long_prompts,
|
||||||
|
max_length=512,
|
||||||
|
dropout_prob=dropout_prob
|
||||||
|
)
|
||||||
|
pe = PromptEmbeds(
|
||||||
|
prompt_embeds
|
||||||
|
)
|
||||||
|
pe.pooled_embeds = pooled_prompt_embeds
|
||||||
|
return pe
|
||||||
|
|
||||||
|
|
||||||
elif isinstance(self.text_encoder, T5EncoderModel):
|
elif isinstance(self.text_encoder, T5EncoderModel):
|
||||||
embeds, attention_mask = train_tools.encode_prompts_pixart(
|
embeds, attention_mask = train_tools.encode_prompts_pixart(
|
||||||
@@ -1989,7 +2181,7 @@ class StableDiffusion:
|
|||||||
named_params = self.named_parameters(vae=False, unet=unet, text_encoder=False, state_dict_keys=True)
|
named_params = self.named_parameters(vae=False, unet=unet, text_encoder=False, state_dict_keys=True)
|
||||||
unet_lr = unet_lr if unet_lr is not None else default_lr
|
unet_lr = unet_lr if unet_lr is not None else default_lr
|
||||||
params = []
|
params = []
|
||||||
if self.is_pixart or self.is_auraflow:
|
if self.is_pixart or self.is_auraflow or self.is_flux:
|
||||||
for param in named_params.values():
|
for param in named_params.values():
|
||||||
if param.requires_grad:
|
if param.requires_grad:
|
||||||
params.append(param)
|
params.append(param)
|
||||||
@@ -2035,7 +2227,7 @@ class StableDiffusion:
|
|||||||
def save_device_state(self):
|
def save_device_state(self):
|
||||||
# saves the current device state for all modules
|
# saves the current device state for all modules
|
||||||
# this is useful for when we want to alter the state and restore it
|
# this is useful for when we want to alter the state and restore it
|
||||||
if self.is_pixart or self.is_v3 or self.is_auraflow:
|
if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux:
|
||||||
unet_has_grad = self.unet.proj_out.weight.requires_grad
|
unet_has_grad = self.unet.proj_out.weight.requires_grad
|
||||||
else:
|
else:
|
||||||
unet_has_grad = self.unet.conv_in.weight.requires_grad
|
unet_has_grad = self.unet.conv_in.weight.requires_grad
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import hashlib
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Union
|
from typing import TYPE_CHECKING, Union, List
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from torch.cuda.amp import GradScaler
|
from torch.cuda.amp import GradScaler
|
||||||
@@ -766,6 +766,73 @@ def encode_prompts_auraflow(
|
|||||||
|
|
||||||
return prompt_embeds, prompt_attention_mask
|
return prompt_embeds, prompt_attention_mask
|
||||||
|
|
||||||
|
def encode_prompts_flux(
|
||||||
|
tokenizer: List[Union['CLIPTokenizer','T5Tokenizer']],
|
||||||
|
text_encoder: List[Union['CLIPTextModel', 'T5EncoderModel']],
|
||||||
|
prompts: list[str],
|
||||||
|
truncate: bool = True,
|
||||||
|
max_length=None,
|
||||||
|
dropout_prob=0.0,
|
||||||
|
):
|
||||||
|
if max_length is None:
|
||||||
|
max_length = 512
|
||||||
|
|
||||||
|
if dropout_prob > 0.0:
|
||||||
|
# randomly drop out prompts
|
||||||
|
prompts = [
|
||||||
|
prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts
|
||||||
|
]
|
||||||
|
|
||||||
|
device = text_encoder[0].device
|
||||||
|
dtype = text_encoder[0].dtype
|
||||||
|
|
||||||
|
batch_size = len(prompts)
|
||||||
|
|
||||||
|
# clip
|
||||||
|
text_inputs = tokenizer[0](
|
||||||
|
prompts,
|
||||||
|
padding="max_length",
|
||||||
|
max_length=tokenizer[0].model_max_length,
|
||||||
|
truncation=True,
|
||||||
|
return_overflowing_tokens=False,
|
||||||
|
return_length=False,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
text_input_ids = text_inputs.input_ids
|
||||||
|
|
||||||
|
prompt_embeds = text_encoder[0](text_input_ids.to(device), output_hidden_states=False)
|
||||||
|
|
||||||
|
# Use pooled output of CLIPTextModel
|
||||||
|
pooled_prompt_embeds = prompt_embeds.pooler_output
|
||||||
|
pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=dtype, device=device)
|
||||||
|
|
||||||
|
# T5
|
||||||
|
text_inputs = tokenizer[1](
|
||||||
|
prompts,
|
||||||
|
padding="max_length",
|
||||||
|
max_length=max_length,
|
||||||
|
truncation=True,
|
||||||
|
return_length=False,
|
||||||
|
return_overflowing_tokens=False,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
text_input_ids = text_inputs.input_ids
|
||||||
|
|
||||||
|
prompt_embeds = text_encoder[1](text_input_ids.to(device), output_hidden_states=False)[0]
|
||||||
|
|
||||||
|
dtype = text_encoder[1].dtype
|
||||||
|
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||||
|
|
||||||
|
# prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
|
||||||
|
# prompt_embeds = prompt_embeds * prompt_attention_mask
|
||||||
|
# _, seq_len, _ = prompt_embeds.shape
|
||||||
|
|
||||||
|
# they dont do prompt attention mask?
|
||||||
|
# prompt_attention_mask = torch.ones((batch_size, seq_len), dtype=dtype, device=device)
|
||||||
|
|
||||||
|
return prompt_embeds, pooled_prompt_embeds
|
||||||
|
|
||||||
|
|
||||||
# for XL
|
# for XL
|
||||||
def get_add_time_ids(
|
def get_add_time_ids(
|
||||||
|
|||||||
Reference in New Issue
Block a user