mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 02:01:29 +00:00
Some work on sd3 training. Not working
This commit is contained in:
@@ -25,6 +25,7 @@ from diffusers import (
|
||||
HeunDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
StableDiffusion3Pipeline
|
||||
)
|
||||
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
|
||||
import torch
|
||||
@@ -580,6 +581,58 @@ def encode_prompts_xl(
|
||||
|
||||
return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds
|
||||
|
||||
def encode_prompts_sd3(
    tokenizers: list['CLIPTokenizer'],
    text_encoders: list[Union['CLIPTextModel', 'CLIPTextModelWithProjection', T5EncoderModel]],
    prompts: list[str],
    num_images_per_prompt: int = 1,
    truncate: bool = True,
    max_length=None,
    dropout_prob=0.0,
    pipeline: StableDiffusion3Pipeline = None,
):
    """Encode prompts for Stable Diffusion 3 using the pipeline's own encoders.

    The same prompt text is sent through the pipeline's two CLIP encoders
    (``clip_model_index`` 0 and 1) and its T5 encoder. The two CLIP sequence
    embeddings are concatenated on the last dimension, right-padded with
    zeros on that dimension up to the T5 embedding width, and then joined
    with the T5 embeddings along dimension ``-2``. The two pooled CLIP
    outputs are concatenated on the last dimension.

    Args:
        tokenizers: Currently unused; tokenization is done by ``pipeline``.
        text_encoders: Only ``text_encoders[0].device`` is read, to choose
            the device passed to the pipeline helpers.
        prompts: A prompt string or a list of prompt strings.
        num_images_per_prompt: Forwarded unchanged to the pipeline helpers.
        truncate: Currently unused.
        max_length: Currently unused.
        dropout_prob: Currently unused.
        pipeline: Object providing ``_get_clip_prompt_embeds`` and
            ``_get_t5_prompt_embeds``. Required despite the ``None`` default.

    Returns:
        A ``(prompt_embeds, pooled_prompt_embeds)`` tuple of tensors.

    Raises:
        ValueError: If ``pipeline`` is ``None``.
    """
    if pipeline is None:
        # Fail early with a clear message instead of an AttributeError on None.
        raise ValueError("encode_prompts_sd3 requires a pipeline, but pipeline=None was given")

    # All three encoders receive the same text; normalise a bare string once.
    prompt_list = [prompts] if isinstance(prompts, str) else prompts

    device = text_encoders[0].device

    prompt_embed, pooled_prompt_embed = pipeline._get_clip_prompt_embeds(
        prompt=prompt_list,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        clip_skip=None,
        clip_model_index=0,
    )
    prompt_2_embed, pooled_prompt_2_embed = pipeline._get_clip_prompt_embeds(
        prompt=prompt_list,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        clip_skip=None,
        clip_model_index=1,
    )
    clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)

    t5_prompt_embed = pipeline._get_t5_prompt_embeds(
        prompt=prompt_list,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
    )

    # Right-pad the CLIP feature dimension with zeros so it matches the T5
    # feature width; required for the concatenation along dim -2 below.
    clip_prompt_embeds = torch.nn.functional.pad(
        clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
    )

    prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
    pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)

    return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
|
||||
# ref for long prompts https://github.com/huggingface/diffusers/issues/2136
|
||||
def text_encode(text_encoder: 'CLIPTextModel', tokens, truncate: bool = True, max_length=None):
|
||||
@@ -720,18 +773,22 @@ def concat_embeddings(
|
||||
|
||||
|
||||
def add_all_snr_to_noise_scheduler(noise_scheduler, device):
    """Attach a per-timestep signal-to-noise table to ``noise_scheduler``.

    Computes ``(sqrt(alphas_cumprod) / sqrt(1 - alphas_cumprod)) ** 2`` from
    ``noise_scheduler.alphas_cumprod`` and stores it, moved to ``device``,
    as ``noise_scheduler.all_snr``.

    This is best effort: if the computation fails for any reason (for
    example the scheduler has no ``alphas_cumprod``), the error and a
    notice are printed and the scheduler is left without ``all_snr``.

    Args:
        noise_scheduler: Object to annotate; left untouched if it already
            has an ``all_snr`` attribute.
        device: Target passed to ``Tensor.to`` for the stored table.
    """
    if hasattr(noise_scheduler, "all_snr"):
        # Already computed on an earlier call; nothing to do.
        return

    try:
        with torch.no_grad():
            alphas_cumprod = noise_scheduler.alphas_cumprod
            alpha = torch.sqrt(alphas_cumprod)
            sigma = torch.sqrt(1.0 - alphas_cumprod)
            all_snr = (alpha / sigma) ** 2
            all_snr.requires_grad = False
            noise_scheduler.all_snr = all_snr.to(device)
    except Exception as e:
        # Deliberately broad: failing here must not abort the caller.
        print(e)
        print("Failed to add all_snr to noise_scheduler")
|
||||
|
||||
|
||||
def get_all_snr(noise_scheduler, device):
|
||||
|
||||
Reference in New Issue
Block a user