Big refactor of the SD runner; added an image generator

This commit is contained in:
Jaret Burkett
2023-08-03 14:51:25 -06:00
parent 75ec5d9292
commit 66c6f0f6f7
16 changed files with 923 additions and 430 deletions

View File

@@ -1,7 +1,6 @@
# ref:
# - https://github.com/p1atdev/LECO/blob/main/train_lora.py
import random
import time
from collections import OrderedDict
import os
from typing import Optional
@@ -14,16 +13,12 @@ from toolkit.paths import REPOS_ROOT
import sys
from toolkit.stable_diffusion_model import PromptEmbeds
sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
from toolkit.train_tools import get_torch_dtype
import gc
from toolkit import train_tools
import torch
from leco import train_util, model_util
from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion
from .BaseSDTrainProcess import BaseSDTrainProcess
class ACTION_TYPES_SLIDER:
@@ -131,7 +126,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..")
if not self.slider_config.prompt_tensors:
# shuffle
random.shuffle(self.prompt_txt_list)
@@ -175,8 +169,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
for neutral in tqdm(neutral_list, desc="Encoding prompts", leave=False):
for target in self.slider_config.targets:
prompt_list = [
f"{target.target_class}", # target_class
f"{target.target_class} {neutral}", # target_class with neutral
f"{target.target_class}", # target_class
f"{target.target_class} {neutral}", # target_class with neutral
f"{target.positive}", # positive_target
f"{target.positive} {neutral}", # positive_target with neutral
f"{target.negative}", # negative_target
@@ -320,7 +314,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
)
]
# move to cpu to save vram
# We don't need text encoder anymore, but keep it on cpu for sampling
# if text encoder is list
@@ -364,7 +357,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
loss_function = torch.nn.MSELoss()
def get_noise_pred(neg, pos, gs, cts, dn):
return self.predict_noise(
return self.sd.predict_noise(
latents=dn,
text_embeddings=train_tools.concat_prompt_embeddings(
neg, # negative prompt
@@ -391,9 +384,11 @@ class TrainSliderProcess(BaseSDTrainProcess):
).item()
# get noise
noise = self.get_latent_noise(
noise = self.sd.get_latent_noise(
pixel_height=height,
pixel_width=width,
batch_size=self.train_config.batch_size,
noise_offset=self.train_config.noise_offset,
).to(self.device_torch, dtype=dtype)
# get latents
@@ -403,7 +398,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
with self.network:
assert self.network.is_active
self.network.multiplier = multiplier * rand_weight
denoised_latents = self.diffuse_some_steps(
denoised_latents = self.sd.diffuse_some_steps(
latents, # pass simple noise latents
train_tools.concat_prompt_embeddings(
prompt_pair.positive_target, # unconditional