SDXL should be working, but I broke something where it is not converging.

This commit is contained in:
Jaret Burkett
2023-07-25 13:50:59 -06:00
parent 52f02d53f1
commit cb70c03273
11 changed files with 458 additions and 166 deletions

View File

@@ -3,19 +3,22 @@
import time
from collections import OrderedDict
import os
from typing import Optional
from toolkit.config_modules import SliderConfig
from toolkit.paths import REPOS_ROOT
import sys
from toolkit.stable_diffusion_model import PromptEmbeds
sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import gc
from toolkit import train_tools
import torch
from leco import train_util, model_util
from leco.prompt_util import PromptEmbedsCache
from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion
@@ -29,7 +32,6 @@ def flush():
gc.collect()
class EncodedPromptPair:
def __init__(
self,
@@ -54,6 +56,19 @@ class EncodedPromptPair:
self.weight = weight
class PromptEmbedsCache: # 使いまわしたいので
prompts: dict[str, PromptEmbeds] = {}
def __setitem__(self, __name: str, __value: PromptEmbeds) -> None:
self.prompts[__name] = __value
def __getitem__(self, __name: str) -> Optional[PromptEmbeds]:
if __name in self.prompts:
return self.prompts[__name]
else:
return None
class EncodedAnchor:
def __init__(
self,
@@ -89,19 +104,17 @@ class TrainSliderProcess(BaseSDTrainProcess):
with torch.no_grad():
neutral = ""
for target in self.slider_config.targets:
# build the cache
for prompt in [
target.target_class,
target.positive,
target.negative,
neutral # empty neutral
]:
if cache[prompt] is None:
cache[prompt] = self.sd.encode_prompt(prompt)
for resolution in self.slider_config.resolutions:
width, height = resolution
# build the cache
for prompt in [
target.target_class,
target.positive,
target.negative,
neutral # empty neutral
]:
if cache[prompt] == None:
cache[prompt] = train_util.encode_prompts(
self.sd.tokenizer, self.sd.text_encoder, [prompt]
)
only_erase = len(target.positive.strip()) == 0
only_enhance = len(target.negative.strip()) == 0
@@ -184,9 +197,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
anchor.neg_prompt # empty neutral
]:
if cache[prompt] == None:
cache[prompt] = train_util.encode_prompts(
self.sd.tokenizer, self.sd.text_encoder, [prompt]
)
cache[prompt] = self.sd.encode_prompt(prompt)
anchor_pairs += [
EncodedAnchor(
@@ -198,7 +209,12 @@ class TrainSliderProcess(BaseSDTrainProcess):
# move to cpu to save vram
# We don't need text encoder anymore, but keep it on cpu for sampling
self.sd.text_encoder.to("cpu")
# if text encoder is list
if isinstance(self.sd.text_encoder, list):
for encoder in self.sd.text_encoder:
encoder.to("cpu")
else:
self.sd.text_encoder.to("cpu")
self.prompt_cache = cache
self.prompt_pairs = prompt_pairs
self.anchor_pairs = anchor_pairs
@@ -220,6 +236,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
negative = prompt_pair.negative
positive = prompt_pair.positive
weight = prompt_pair.weight
multiplier = prompt_pair.multiplier
unet = self.sd.unet
noise_scheduler = self.sd.noise_scheduler
@@ -227,8 +244,20 @@ class TrainSliderProcess(BaseSDTrainProcess):
lr_scheduler = self.lr_scheduler
loss_function = torch.nn.MSELoss()
def get_noise_pred(p, n):
return self.predict_noise(
latents=denoised_latents,
text_embeddings=train_tools.concat_prompt_embeddings(
p, # unconditional
n, # positive
self.train_config.batch_size,
),
timestep=current_timestep,
guidance_scale=1,
)
# set network multiplier
self.network.multiplier = prompt_pair.multiplier
self.network.multiplier = multiplier
with torch.no_grad():
self.sd.noise_scheduler.set_timesteps(
@@ -254,9 +283,10 @@ class TrainSliderProcess(BaseSDTrainProcess):
with self.network:
assert self.network.is_active
self.network.multiplier = multiplier
denoised_latents = self.diffuse_some_steps(
latents, # pass simple noise latents
train_util.concat_embeddings(
train_tools.concat_prompt_embeddings(
positive, # unconditional
target_class, # target
self.train_config.batch_size,
@@ -272,43 +302,11 @@ class TrainSliderProcess(BaseSDTrainProcess):
int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
]
# with network: 0 weight LoRA is enabled outside "with network:"
positive_latents = train_util.predict_noise( # positive_latents
unet,
noise_scheduler,
current_timestep,
denoised_latents,
train_util.concat_embeddings(
positive, # unconditional
negative, # positive
self.train_config.batch_size,
),
guidance_scale=1,
).to("cpu", dtype=torch.float32)
neutral_latents = train_util.predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
train_util.concat_embeddings(
positive, # unconditional
neutral, # neutral
self.train_config.batch_size,
),
guidance_scale=1,
).to("cpu", dtype=torch.float32)
unconditional_latents = train_util.predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
train_util.concat_embeddings(
positive, # unconditional
positive, # unconditional
self.train_config.batch_size,
),
guidance_scale=1,
).to("cpu", dtype=torch.float32)
positive_latents = get_noise_pred(positive, negative)
neutral_latents = get_noise_pred(positive, neutral)
unconditional_latents = get_noise_pred(positive, positive)
anchor_loss = None
if len(self.anchor_pairs) > 0:
@@ -317,51 +315,19 @@ class TrainSliderProcess(BaseSDTrainProcess):
torch.randint(0, len(self.anchor_pairs), (1,)).item()
]
with torch.no_grad():
anchor_target_noise = train_util.predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
train_util.concat_embeddings(
anchor.prompt,
anchor.neg_prompt,
self.train_config.batch_size,
),
guidance_scale=1,
).to("cpu", dtype=torch.float32)
anchor_target_noise = get_noise_pred(anchor.prompt, anchor.neg_prompt)
with self.network:
# anchor whatever weight prompt pair is using
pos_nem_mult = 1.0 if prompt_pair.multiplier > 0 else -1.0
self.network.multiplier = anchor.multiplier * pos_nem_mult
anchor_pred_noise = train_util.predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
train_util.concat_embeddings(
anchor.prompt,
anchor.neg_prompt,
self.train_config.batch_size,
),
guidance_scale=1,
).to("cpu", dtype=torch.float32)
anchor_pred_noise = get_noise_pred(anchor.prompt, anchor.neg_prompt)
self.network.multiplier = prompt_pair.multiplier
with self.network:
self.network.multiplier = prompt_pair.multiplier
target_latents = train_util.predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
train_util.concat_embeddings(
positive, # unconditional
target_class, # target
self.train_config.batch_size,
),
guidance_scale=1,
).to("cpu", dtype=torch.float32)
target_latents = get_noise_pred(positive, target_class)
# if self.logging_config.verbose:
# self.print("target_latents:", target_latents[0, 0, :5, :5])