Some work on sd3 training. Not working

This commit is contained in:
Jaret Burkett
2024-06-13 12:19:16 -06:00
parent cb5d28cba9
commit bd10d2d668
12 changed files with 306 additions and 36 deletions

View File

@@ -29,6 +29,7 @@ import gc
import torch
from jobs.process import BaseSDTrainProcess
from torchvision import transforms
import math
@@ -366,7 +367,29 @@ class SDTrainer(BaseSDTrainProcess):
loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2)
loss = loss_per_element
else:
if self.train_config.loss_type == "mae":
# handle flow matching ref https://github.com/huggingface/diffusers/blob/ec068f9b5bf7c65f93125ec889e0ff1792a00da1/examples/dreambooth/train_dreambooth_lora_sd3.py#L1485C17-L1495C100
if self.sd.is_v3:
target = noisy_latents.detach()
bsz = pred.shape[0]
# todo implement others
# weighing_scheme =
# 3 just do mode for now?
# if args.weighting_scheme == "sigma_sqrt":
sigmas = self.sd.noise_scheduler.get_sigmas(timesteps, pred.ndim, dtype, self.device_torch)
weighting = (sigmas ** -2.0).float()
# elif args.weighting_scheme == "logit_normal":
# # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
# u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device)
# weighting = torch.nn.functional.sigmoid(u)
# elif args.weighting_scheme == "mode":
# mode_scale = 1.29
# See sec 3.1 in the SD3 paper (20).
# u = torch.rand(size=(bsz,), device=pred.device)
# weighting = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
loss = (weighting.float() * (pred.float() - target.float()) ** 2).reshape(target.shape[0], -1)
elif self.train_config.loss_type == "mae":
loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
else:
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")

View File

@@ -1244,6 +1244,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
conv_alpha=self.network_config.conv_alpha,
is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
is_v2=self.model_config.is_v2,
is_v3=self.model_config.is_v3,
is_ssd=self.model_config.is_ssd,
is_vega=self.model_config.is_vega,
dropout=self.network_config.dropout,

View File

@@ -1,7 +1,7 @@
torch
torchvision
safetensors
diffusers==0.26.3
diffusers
transformers
lycoris-lora==1.8.3
flatten_json

View File

@@ -336,6 +336,7 @@ class ModelConfig:
self.is_xl: bool = kwargs.get('is_xl', False)
self.is_pixart: bool = kwargs.get('is_pixart', False)
self.is_pixart_sigma: bool = kwargs.get('is_pixart', False)
self.is_v3: bool = kwargs.get('is_v3', False)
if self.is_pixart_sigma:
self.is_pixart = True
self.is_ssd: bool = kwargs.get('is_ssd', False)

View File

@@ -1353,6 +1353,8 @@ class LatentCachingMixin:
file_item.latent_space_version = self.sd.model_config.latent_space_version
elif self.sd.is_xl:
file_item.latent_space_version = 'sdxl'
elif self.sd.is_v3:
file_item.latent_space_version = 'sd3'
else:
file_item.latent_space_version = 'sd1'
file_item.is_caching_to_disk = to_disk

View File

@@ -152,6 +152,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
train_unet: Optional[bool] = True,
is_sdxl=False,
is_v2=False,
is_v3=False,
is_pixart: bool = False,
use_bias: bool = False,
is_lorm: bool = False,
@@ -200,6 +201,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
self.multiplier = multiplier
self.is_sdxl = is_sdxl
self.is_v2 = is_v2
self.is_v3 = is_v3
self.is_pixart = is_pixart
self.network_type = network_type
if self.network_type.lower() == "dora":
@@ -233,7 +235,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
target_replace_modules: List[torch.nn.Module],
) -> List[LoRAModule]:
unet_prefix = self.LORA_PREFIX_UNET
if is_pixart:
if is_pixart or is_v3:
unet_prefix = f"lora_transformer"
prefix = (
@@ -346,6 +348,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
target_modules += target_conv_modules
if is_v3:
target_modules = ["SD3Transformer2DModel"]
if train_unet:
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
else:

View File

@@ -48,7 +48,7 @@ class LoRAGenerator(torch.nn.Module):
head_size: int = 512,
num_mlp_layers: int = 1,
output_size: int = 768,
dropout: float = 0.5
dropout: float = 0.0
):
super().__init__()
self.input_size = input_size
@@ -131,8 +131,12 @@ class InstantLoRAMidModule(torch.nn.Module):
x_chunk = x_chunks[i]
# reshape
weight_chunk = weight_chunk.view(self.down_shape)
# run a simple lenear layer with the down weight
x_chunk = x_chunk @ weight_chunk.T
# check if is conv or linear
if len(weight_chunk.shape) == 4:
x_chunk = nn.functional.conv2d(x_chunk, weight_chunk)
else:
# run a simple linear layer with the down weight
x_chunk = x_chunk @ weight_chunk.T
x_out.append(x_chunk)
x = torch.cat(x_out, dim=0)
return x
@@ -158,8 +162,12 @@ class InstantLoRAMidModule(torch.nn.Module):
x_chunk = x_chunks[i]
# reshape
weight_chunk = weight_chunk.view(self.up_shape)
# run a simple lenear layer with the down weight
x_chunk = x_chunk @ weight_chunk.T
# check if is conv or linear
if len(weight_chunk.shape) == 4:
x_chunk = nn.functional.conv2d(x_chunk, weight_chunk)
else:
# run a simple linear layer with the down weight
x_chunk = x_chunk @ weight_chunk.T
x_out.append(x_chunk)
x = torch.cat(x_out, dim=0)
return x

View File

@@ -4,7 +4,6 @@ import torch
import sys
from PIL import Image
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from torch.nn import Parameter
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

View File

@@ -13,9 +13,12 @@ from diffusers import (
HeunDiscreteScheduler,
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
LCMScheduler
LCMScheduler,
FlowMatchEulerDiscreteScheduler,
)
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
from k_diffusion.external import CompVisDenoiser
from toolkit.samplers.custom_lcm_scheduler import CustomLCMScheduler
@@ -112,6 +115,15 @@ def get_sampler(
scheduler_cls = LCMScheduler
elif sampler == "custom_lcm":
scheduler_cls = CustomLCMScheduler
elif sampler == "flowmatch":
scheduler_cls = CustomFlowMatchEulerDiscreteScheduler
config_to_use = {
"_class_name": "FlowMatchEulerDiscreteScheduler",
"_diffusers_version": "0.29.0.dev0",
"num_train_timesteps": 1000,
"shift": 3.0
}
config = copy.deepcopy(config_to_use)
config.update(sched_init_args)

View File

@@ -0,0 +1,32 @@
from typing import Union
from diffusers import FlowMatchEulerDiscreteScheduler
import torch
class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
sigmas = self.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = self.timesteps.to(device)
timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
n_dim = original_samples.ndim
sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * original_samples
return noisy_model_input
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
return sample

View File

@@ -40,13 +40,13 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline
StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline
import diffusers
from diffusers import \
AutoencoderKL, \
UNet2DConditionModel
from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler
from transformers import T5EncoderModel
from transformers import T5EncoderModel, BitsAndBytesConfig
from toolkit.util.pixart_sigma_patch import pixart_sigma_init_patched_inputs, PixArtSigmaPipeline
from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
@@ -147,6 +147,7 @@ class StableDiffusion:
self.is_xl = model_config.is_xl
self.is_v2 = model_config.is_v2
self.is_ssd = model_config.is_ssd
self.is_v3 = model_config.is_v3
self.is_vega = model_config.is_vega
self.is_pixart = model_config.is_pixart
@@ -236,6 +237,64 @@ class StableDiffusion:
te1_state_dict['text_projection.weight'] = replacement_weight.to(self.device_torch, dtype=dtype)
flush()
print("Injecting alt weights")
elif self.model_config.is_v3:
if self.custom_pipeline is not None:
pipln = self.custom_pipeline
else:
pipln = StableDiffusion3Pipeline
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model_id = "stabilityai/stable-diffusion-3-medium"
text_encoder3 = T5EncoderModel.from_pretrained(
model_id,
subfolder="text_encoder_3",
# quantization_config=quantization_config,
revision="refs/pr/26",
device_map="cuda"
)
# see if path exists
if not os.path.exists(model_path) or os.path.isdir(model_path):
try:
# try to load with default diffusers
pipe = pipln.from_pretrained(
model_path,
dtype=dtype,
device=self.device_torch,
text_encoder_3=text_encoder3,
# variant="fp16",
use_safetensors=True,
revision="refs/pr/26",
repo_type="model",
ignore_patterns=["*.md", "*..gitattributes"],
**load_args
)
except Exception as e:
print(f"Error loading from pretrained: {e}")
raise e
else:
pipe = pipln.from_single_file(
model_path,
device=self.device_torch,
torch_dtype=self.torch_dtype,
text_encoder_3=text_encoder3,
)
flush()
text_encoders = [pipe.text_encoder, pipe.text_encoder_2, pipe.text_encoder_3]
tokenizer = [pipe.tokenizer, pipe.tokenizer_2, pipe.tokenizer_3]
# replace the to function with a no-op since it throws an error instead of a warning
# text_encoders[2].to = lambda *args, **kwargs: None
for text_encoder in text_encoders:
text_encoder.to(self.device_torch, dtype=dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
text_encoder = text_encoders
elif self.model_config.is_pixart:
te_kwargs = {}
# handle quantization of TE
@@ -361,8 +420,8 @@ class StableDiffusion:
# add hacks to unet to help training
# pipe.unet = prepare_unet_for_training(pipe.unet)
if self.is_pixart:
# pixart doesnt use a unet
if self.is_pixart or self.is_v3:
# pixart and sd3 dont use a unet
self.unet = pipe.transformer
else:
self.unet: 'UNet2DConditionModel' = pipe.unet
@@ -487,6 +546,8 @@ class StableDiffusion:
Pipe = StableDiffusionKDiffusionXLPipeline
elif self.is_xl:
Pipe = StableDiffusionXLPipeline
elif self.is_v3:
Pipe = StableDiffusion3Pipeline
else:
Pipe = StableDiffusionPipeline
@@ -515,15 +576,30 @@ class StableDiffusion:
if self.is_xl:
pipeline = Pipe(
vae=self.vae,
unet=self.unet,
transformer=self.unet,
text_encoder=self.text_encoder[0],
text_encoder_2=self.text_encoder[1],
text_encoder_3=self.text_encoder[2],
tokenizer=self.tokenizer[0],
tokenizer_2=self.tokenizer[1],
tokenizer_3=self.tokenizer[2],
scheduler=noise_scheduler,
**extra_args
).to(self.device_torch)
pipeline.watermark = None
elif self.is_v3:
pipeline = Pipe(
vae=self.vae,
transformer=self.unet,
text_encoder=self.text_encoder[0],
text_encoder_2=self.text_encoder[1],
text_encoder_3=self.text_encoder[2],
tokenizer=self.tokenizer[0],
tokenizer_2=self.tokenizer[1],
tokenizer_3=self.tokenizer[2],
scheduler=noise_scheduler,
**extra_args
)
elif self.is_pixart:
pipeline = PixArtAlphaPipeline(
vae=self.vae,
@@ -576,7 +652,7 @@ class StableDiffusion:
if self.network is not None:
start_multiplier = self.network.multiplier
pipeline.to(self.device_torch)
# pipeline.to(self.device_torch)
with network:
with torch.no_grad():
@@ -744,6 +820,19 @@ class StableDiffusion:
latents=gen_config.latents,
**extra
).images[0]
elif self.is_v3:
img = pipeline(
prompt_embeds=conditional_embeds.text_embeds,
pooled_prompt_embeds=conditional_embeds.pooled_embeds,
negative_prompt_embeds=unconditional_embeds.text_embeds,
negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
latents=gen_config.latents,
**extra
).images[0]
elif self.is_pixart:
# needs attention masks for some reason
img = pipeline(
@@ -1004,6 +1093,20 @@ class StableDiffusion:
)
return torch.cat(out_chunks, dim=0)
def precondition_model_outputs_sd3(model_output, model_input, timestep_tensor):
mo_chunks = torch.chunk(model_output, model_output.shape[0], dim=0)
mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
out_chunks = []
# unsqueeze if timestep is zero dim
for idx in range(model_output.shape[0]):
sigmas = self.noise_scheduler.get_sigmas(timestep_chunks[idx], n_dim=model_output.ndim, dtype=model_output.dtype, device=model_output.device)
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
# Preconditioning of the model outputs.
out = mo_chunks[idx] * (-sigmas) + mi_chunks[idx]
out_chunks.append(out)
return torch.cat(out_chunks, dim=0)
if self.is_xl:
with torch.no_grad():
# 16, 6 for bs of 4
@@ -1177,12 +1280,22 @@ class StableDiffusion:
self.unet.to(self.device_torch)
if self.unet.dtype != self.torch_dtype:
self.unet = self.unet.to(dtype=self.torch_dtype)
noise_pred = self.unet(
latent_model_input.to(self.device_torch, self.torch_dtype),
timestep,
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
**kwargs,
).sample
if self.is_v3:
noise_pred = self.unet(
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
timestep=timestep,
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype),
**kwargs,
).sample
noise_pred = precondition_model_outputs_sd3(noise_pred, latent_model_input, timestep)
else:
noise_pred = self.unet(
latent_model_input.to(self.device_torch, self.torch_dtype),
timestep=timestep,
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
**kwargs,
).sample
conditional_pred = noise_pred
@@ -1343,6 +1456,19 @@ class StableDiffusion:
dropout_prob=dropout_prob,
)
)
if self.is_v3:
return PromptEmbeds(
train_tools.encode_prompts_sd3(
self.tokenizer,
self.text_encoder,
prompt,
num_images_per_prompt=num_images_per_prompt,
truncate=not long_prompts,
max_length=max_length,
dropout_prob=dropout_prob,
pipeline=self.pipeline,
)
)
elif self.is_pixart:
embeds, attention_mask = train_tools.encode_prompts_pixart(
self.tokenizer,
@@ -1735,7 +1861,7 @@ class StableDiffusion:
def save_device_state(self):
# saves the current device state for all modules
# this is useful for when we want to alter the state and restore it
if self.is_pixart:
if self.is_pixart or self.is_v3:
unet_has_grad = self.unet.proj_out.weight.requires_grad
else:
unet_has_grad = self.unet.conv_in.weight.requires_grad
@@ -1755,11 +1881,15 @@ class StableDiffusion:
if isinstance(self.text_encoder, list):
self.device_state['text_encoder']: List[dict] = []
for encoder in self.text_encoder:
try:
te_has_grad = encoder.text_model.final_layer_norm.weight.requires_grad
except:
te_has_grad = encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
self.device_state['text_encoder'].append({
'training': encoder.training,
'device': encoder.device,
# todo there has to be a better way to do this
'requires_grad': encoder.text_model.final_layer_norm.weight.requires_grad
'requires_grad': te_has_grad
})
else:
if isinstance(self.text_encoder, T5EncoderModel):

View File

@@ -25,6 +25,7 @@ from diffusers import (
HeunDiscreteScheduler,
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
StableDiffusion3Pipeline
)
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
import torch
@@ -580,6 +581,58 @@ def encode_prompts_xl(
return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds
def encode_prompts_sd3(
tokenizers: list['CLIPTokenizer'],
text_encoders: list[Union['CLIPTextModel', 'CLIPTextModelWithProjection', T5EncoderModel]],
prompts: list[str],
num_images_per_prompt: int = 1,
truncate: bool = True,
max_length=None,
dropout_prob=0.0,
pipeline: StableDiffusion3Pipeline = None,
):
text_embeds_list = []
pooled_text_embeds = None # always text_encoder_2's pool
prompt_2 = prompts
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
prompt_3 = prompts
prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
device = text_encoders[0].device
prompt_embed, pooled_prompt_embed = pipeline._get_clip_prompt_embeds(
prompt=prompts,
device=device,
num_images_per_prompt=num_images_per_prompt,
clip_skip=None,
clip_model_index=0,
)
prompt_2_embed, pooled_prompt_2_embed = pipeline._get_clip_prompt_embeds(
prompt=prompt_2,
device=device,
num_images_per_prompt=num_images_per_prompt,
clip_skip=None,
clip_model_index=1,
)
clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
t5_prompt_embed = pipeline._get_t5_prompt_embeds(
prompt=prompt_3,
num_images_per_prompt=num_images_per_prompt,
device=device
)
clip_prompt_embeds = torch.nn.functional.pad(
clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
)
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
return prompt_embeds, pooled_prompt_embeds
# ref for long prompts https://github.com/huggingface/diffusers/issues/2136
def text_encode(text_encoder: 'CLIPTextModel', tokens, truncate: bool = True, max_length=None):
@@ -720,18 +773,22 @@ def concat_embeddings(
def add_all_snr_to_noise_scheduler(noise_scheduler, device):
if hasattr(noise_scheduler, "all_snr"):
return
# compute it
with torch.no_grad():
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
alpha = sqrt_alphas_cumprod
sigma = sqrt_one_minus_alphas_cumprod
all_snr = (alpha / sigma) ** 2
all_snr.requires_grad = False
noise_scheduler.all_snr = all_snr.to(device)
try:
if hasattr(noise_scheduler, "all_snr"):
return
# compute it
with torch.no_grad():
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
alpha = sqrt_alphas_cumprod
sigma = sqrt_one_minus_alphas_cumprod
all_snr = (alpha / sigma) ** 2
all_snr.requires_grad = False
noise_scheduler.all_snr = all_snr.to(device)
except Exception as e:
print(e)
print("Failed to add all_snr to noise_scheduler")
def get_all_snr(noise_scheduler, device):