mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
SDXL should be working, but I broke something where it is not converging.
This commit is contained in:
@@ -3,19 +3,22 @@
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from toolkit.config_modules import SliderConfig
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
import sys
|
||||
|
||||
from toolkit.stable_diffusion_model import PromptEmbeds
|
||||
|
||||
sys.path.append(REPOS_ROOT)
|
||||
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
|
||||
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||
import gc
|
||||
from toolkit import train_tools
|
||||
|
||||
import torch
|
||||
from leco import train_util, model_util
|
||||
from leco.prompt_util import PromptEmbedsCache
|
||||
from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion
|
||||
|
||||
|
||||
@@ -29,7 +32,6 @@ def flush():
|
||||
gc.collect()
|
||||
|
||||
|
||||
|
||||
class EncodedPromptPair:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -54,6 +56,19 @@ class EncodedPromptPair:
|
||||
self.weight = weight
|
||||
|
||||
|
||||
class PromptEmbedsCache: # 使いまわしたいので
|
||||
prompts: dict[str, PromptEmbeds] = {}
|
||||
|
||||
def __setitem__(self, __name: str, __value: PromptEmbeds) -> None:
|
||||
self.prompts[__name] = __value
|
||||
|
||||
def __getitem__(self, __name: str) -> Optional[PromptEmbeds]:
|
||||
if __name in self.prompts:
|
||||
return self.prompts[__name]
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class EncodedAnchor:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -89,19 +104,17 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
with torch.no_grad():
|
||||
neutral = ""
|
||||
for target in self.slider_config.targets:
|
||||
# build the cache
|
||||
for prompt in [
|
||||
target.target_class,
|
||||
target.positive,
|
||||
target.negative,
|
||||
neutral # empty neutral
|
||||
]:
|
||||
if cache[prompt] is None:
|
||||
cache[prompt] = self.sd.encode_prompt(prompt)
|
||||
for resolution in self.slider_config.resolutions:
|
||||
width, height = resolution
|
||||
# build the cache
|
||||
for prompt in [
|
||||
target.target_class,
|
||||
target.positive,
|
||||
target.negative,
|
||||
neutral # empty neutral
|
||||
]:
|
||||
if cache[prompt] == None:
|
||||
cache[prompt] = train_util.encode_prompts(
|
||||
self.sd.tokenizer, self.sd.text_encoder, [prompt]
|
||||
)
|
||||
only_erase = len(target.positive.strip()) == 0
|
||||
only_enhance = len(target.negative.strip()) == 0
|
||||
|
||||
@@ -184,9 +197,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
anchor.neg_prompt # empty neutral
|
||||
]:
|
||||
if cache[prompt] == None:
|
||||
cache[prompt] = train_util.encode_prompts(
|
||||
self.sd.tokenizer, self.sd.text_encoder, [prompt]
|
||||
)
|
||||
cache[prompt] = self.sd.encode_prompt(prompt)
|
||||
|
||||
anchor_pairs += [
|
||||
EncodedAnchor(
|
||||
@@ -198,7 +209,12 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
|
||||
# move to cpu to save vram
|
||||
# We don't need text encoder anymore, but keep it on cpu for sampling
|
||||
self.sd.text_encoder.to("cpu")
|
||||
# if text encoder is list
|
||||
if isinstance(self.sd.text_encoder, list):
|
||||
for encoder in self.sd.text_encoder:
|
||||
encoder.to("cpu")
|
||||
else:
|
||||
self.sd.text_encoder.to("cpu")
|
||||
self.prompt_cache = cache
|
||||
self.prompt_pairs = prompt_pairs
|
||||
self.anchor_pairs = anchor_pairs
|
||||
@@ -220,6 +236,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
negative = prompt_pair.negative
|
||||
positive = prompt_pair.positive
|
||||
weight = prompt_pair.weight
|
||||
multiplier = prompt_pair.multiplier
|
||||
|
||||
unet = self.sd.unet
|
||||
noise_scheduler = self.sd.noise_scheduler
|
||||
@@ -227,8 +244,20 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
lr_scheduler = self.lr_scheduler
|
||||
loss_function = torch.nn.MSELoss()
|
||||
|
||||
def get_noise_pred(p, n):
|
||||
return self.predict_noise(
|
||||
latents=denoised_latents,
|
||||
text_embeddings=train_tools.concat_prompt_embeddings(
|
||||
p, # unconditional
|
||||
n, # positive
|
||||
self.train_config.batch_size,
|
||||
),
|
||||
timestep=current_timestep,
|
||||
guidance_scale=1,
|
||||
)
|
||||
|
||||
# set network multiplier
|
||||
self.network.multiplier = prompt_pair.multiplier
|
||||
self.network.multiplier = multiplier
|
||||
|
||||
with torch.no_grad():
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
@@ -254,9 +283,10 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
|
||||
with self.network:
|
||||
assert self.network.is_active
|
||||
self.network.multiplier = multiplier
|
||||
denoised_latents = self.diffuse_some_steps(
|
||||
latents, # pass simple noise latents
|
||||
train_util.concat_embeddings(
|
||||
train_tools.concat_prompt_embeddings(
|
||||
positive, # unconditional
|
||||
target_class, # target
|
||||
self.train_config.batch_size,
|
||||
@@ -272,43 +302,11 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
|
||||
]
|
||||
|
||||
# with network: 0 weight LoRA is enabled outside "with network:"
|
||||
positive_latents = train_util.predict_noise( # positive_latents
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
train_util.concat_embeddings(
|
||||
positive, # unconditional
|
||||
negative, # positive
|
||||
self.train_config.batch_size,
|
||||
),
|
||||
guidance_scale=1,
|
||||
).to("cpu", dtype=torch.float32)
|
||||
neutral_latents = train_util.predict_noise(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
train_util.concat_embeddings(
|
||||
positive, # unconditional
|
||||
neutral, # neutral
|
||||
self.train_config.batch_size,
|
||||
),
|
||||
guidance_scale=1,
|
||||
).to("cpu", dtype=torch.float32)
|
||||
unconditional_latents = train_util.predict_noise(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
train_util.concat_embeddings(
|
||||
positive, # unconditional
|
||||
positive, # unconditional
|
||||
self.train_config.batch_size,
|
||||
),
|
||||
guidance_scale=1,
|
||||
).to("cpu", dtype=torch.float32)
|
||||
positive_latents = get_noise_pred(positive, negative)
|
||||
|
||||
neutral_latents = get_noise_pred(positive, neutral)
|
||||
|
||||
unconditional_latents = get_noise_pred(positive, positive)
|
||||
|
||||
anchor_loss = None
|
||||
if len(self.anchor_pairs) > 0:
|
||||
@@ -317,51 +315,19 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
torch.randint(0, len(self.anchor_pairs), (1,)).item()
|
||||
]
|
||||
with torch.no_grad():
|
||||
anchor_target_noise = train_util.predict_noise(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
train_util.concat_embeddings(
|
||||
anchor.prompt,
|
||||
anchor.neg_prompt,
|
||||
self.train_config.batch_size,
|
||||
),
|
||||
guidance_scale=1,
|
||||
).to("cpu", dtype=torch.float32)
|
||||
anchor_target_noise = get_noise_pred(anchor.prompt, anchor.neg_prompt)
|
||||
with self.network:
|
||||
# anchor whatever weight prompt pair is using
|
||||
pos_nem_mult = 1.0 if prompt_pair.multiplier > 0 else -1.0
|
||||
self.network.multiplier = anchor.multiplier * pos_nem_mult
|
||||
anchor_pred_noise = train_util.predict_noise(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
train_util.concat_embeddings(
|
||||
anchor.prompt,
|
||||
anchor.neg_prompt,
|
||||
self.train_config.batch_size,
|
||||
),
|
||||
guidance_scale=1,
|
||||
).to("cpu", dtype=torch.float32)
|
||||
|
||||
anchor_pred_noise = get_noise_pred(anchor.prompt, anchor.neg_prompt)
|
||||
|
||||
self.network.multiplier = prompt_pair.multiplier
|
||||
|
||||
with self.network:
|
||||
self.network.multiplier = prompt_pair.multiplier
|
||||
target_latents = train_util.predict_noise(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
current_timestep,
|
||||
denoised_latents,
|
||||
train_util.concat_embeddings(
|
||||
positive, # unconditional
|
||||
target_class, # target
|
||||
self.train_config.batch_size,
|
||||
),
|
||||
guidance_scale=1,
|
||||
).to("cpu", dtype=torch.float32)
|
||||
target_latents = get_noise_pred(positive, target_class)
|
||||
|
||||
# if self.logging_config.verbose:
|
||||
# self.print("target_latents:", target_latents[0, 0, :5, :5])
|
||||
|
||||
Reference in New Issue
Block a user