SDXL should be working, but I broke something where it is not converging.

This commit is contained in:
Jaret Burkett
2023-07-25 13:50:59 -06:00
parent 52f02d53f1
commit cb70c03273
11 changed files with 458 additions and 166 deletions

View File

@@ -21,6 +21,8 @@ from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipel
import torch
import re
from toolkit.stable_diffusion_model import PromptEmbeds
SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120
SCHEDULER_TIMESTEPS = 1000
@@ -377,3 +379,19 @@ def apply_noise_offset(noise, noise_offset):
return noise
noise = noise + noise_offset * torch.randn((noise.shape[0], noise.shape[1], 1, 1), device=noise.device)
return noise
def concat_prompt_embeddings(
unconditional: PromptEmbeds,
conditional: PromptEmbeds,
n_imgs: int,
):
text_embeds = torch.cat(
[unconditional.text_embeds, conditional.text_embeds]
).repeat_interleave(n_imgs, dim=0)
pooled_embeds = None
if unconditional.pooled_embeds is not None and conditional.pooled_embeds is not None:
pooled_embeds = torch.cat(
[unconditional.pooled_embeds, conditional.pooled_embeds]
).repeat_interleave(n_imgs, dim=0)
return PromptEmbeds([text_embeds, pooled_embeds])