mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 02:31:17 +00:00
SDXL should be working, but I broke something where it is not converging.
This commit is contained in:
@@ -21,6 +21,8 @@ from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipel
|
||||
import torch
|
||||
import re
|
||||
|
||||
from toolkit.stable_diffusion_model import PromptEmbeds
|
||||
|
||||
SCHEDULER_LINEAR_START = 0.00085
|
||||
SCHEDULER_LINEAR_END = 0.0120
|
||||
SCHEDULER_TIMESTEPS = 1000
|
||||
@@ -377,3 +379,19 @@ def apply_noise_offset(noise, noise_offset):
|
||||
return noise
|
||||
noise = noise + noise_offset * torch.randn((noise.shape[0], noise.shape[1], 1, 1), device=noise.device)
|
||||
return noise
|
||||
|
||||
|
||||
def concat_prompt_embeddings(
|
||||
unconditional: PromptEmbeds,
|
||||
conditional: PromptEmbeds,
|
||||
n_imgs: int,
|
||||
):
|
||||
text_embeds = torch.cat(
|
||||
[unconditional.text_embeds, conditional.text_embeds]
|
||||
).repeat_interleave(n_imgs, dim=0)
|
||||
pooled_embeds = None
|
||||
if unconditional.pooled_embeds is not None and conditional.pooled_embeds is not None:
|
||||
pooled_embeds = torch.cat(
|
||||
[unconditional.pooled_embeds, conditional.pooled_embeds]
|
||||
).repeat_interleave(n_imgs, dim=0)
|
||||
return PromptEmbeds([text_embeds, pooled_embeds])
|
||||
|
||||
Reference in New Issue
Block a user