Some work on sd3 training. Not working

This commit is contained in:
Jaret Burkett
2024-06-13 12:19:16 -06:00
parent cb5d28cba9
commit bd10d2d668
12 changed files with 306 additions and 36 deletions

View File

@@ -25,6 +25,7 @@ from diffusers import (
HeunDiscreteScheduler,
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
StableDiffusion3Pipeline
)
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
import torch
@@ -580,6 +581,58 @@ def encode_prompts_xl(
return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds
def encode_prompts_sd3(
tokenizers: list['CLIPTokenizer'],
text_encoders: list[Union['CLIPTextModel', 'CLIPTextModelWithProjection', T5EncoderModel]],
prompts: list[str],
num_images_per_prompt: int = 1,
truncate: bool = True,
max_length=None,
dropout_prob=0.0,
pipeline: StableDiffusion3Pipeline = None,
):
text_embeds_list = []
pooled_text_embeds = None # always text_encoder_2's pool
prompt_2 = prompts
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
prompt_3 = prompts
prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
device = text_encoders[0].device
prompt_embed, pooled_prompt_embed = pipeline._get_clip_prompt_embeds(
prompt=prompts,
device=device,
num_images_per_prompt=num_images_per_prompt,
clip_skip=None,
clip_model_index=0,
)
prompt_2_embed, pooled_prompt_2_embed = pipeline._get_clip_prompt_embeds(
prompt=prompt_2,
device=device,
num_images_per_prompt=num_images_per_prompt,
clip_skip=None,
clip_model_index=1,
)
clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
t5_prompt_embed = pipeline._get_t5_prompt_embeds(
prompt=prompt_3,
num_images_per_prompt=num_images_per_prompt,
device=device
)
clip_prompt_embeds = torch.nn.functional.pad(
clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
)
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
return prompt_embeds, pooled_prompt_embeds
# ref for long prompts https://github.com/huggingface/diffusers/issues/2136
def text_encode(text_encoder: 'CLIPTextModel', tokens, truncate: bool = True, max_length=None):
@@ -720,18 +773,22 @@ def concat_embeddings(
def add_all_snr_to_noise_scheduler(noise_scheduler, device):
if hasattr(noise_scheduler, "all_snr"):
return
# compute it
with torch.no_grad():
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
alpha = sqrt_alphas_cumprod
sigma = sqrt_one_minus_alphas_cumprod
all_snr = (alpha / sigma) ** 2
all_snr.requires_grad = False
noise_scheduler.all_snr = all_snr.to(device)
try:
if hasattr(noise_scheduler, "all_snr"):
return
# compute it
with torch.no_grad():
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
alpha = sqrt_alphas_cumprod
sigma = sqrt_one_minus_alphas_cumprod
all_snr = (alpha / sigma) ** 2
all_snr.requires_grad = False
noise_scheduler.all_snr = all_snr.to(device)
except Exception as e:
print(e)
print("Failed to add all_snr to noise_scheduler")
def get_all_snr(noise_scheduler, device):