mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 02:31:17 +00:00
Big refactor of SD runner and added image generator
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
# ref:
|
||||
# - https://github.com/p1atdev/LECO/blob/main/train_lora.py
|
||||
import random
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
from typing import Optional
|
||||
@@ -14,16 +13,12 @@ from toolkit.paths import REPOS_ROOT
|
||||
import sys
|
||||
|
||||
from toolkit.stable_diffusion_model import PromptEmbeds
|
||||
|
||||
sys.path.append(REPOS_ROOT)
|
||||
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
|
||||
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
import gc
|
||||
from toolkit import train_tools
|
||||
|
||||
import torch
|
||||
from leco import train_util, model_util
|
||||
from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion
|
||||
from .BaseSDTrainProcess import BaseSDTrainProcess
|
||||
|
||||
|
||||
class ACTION_TYPES_SLIDER:
|
||||
@@ -131,7 +126,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
|
||||
self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..")
|
||||
|
||||
|
||||
if not self.slider_config.prompt_tensors:
|
||||
# shuffle
|
||||
random.shuffle(self.prompt_txt_list)
|
||||
@@ -175,8 +169,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
for neutral in tqdm(neutral_list, desc="Encoding prompts", leave=False):
|
||||
for target in self.slider_config.targets:
|
||||
prompt_list = [
|
||||
f"{target.target_class}", # target_class
|
||||
f"{target.target_class} {neutral}", # target_class with neutral
|
||||
f"{target.target_class}", # target_class
|
||||
f"{target.target_class} {neutral}", # target_class with neutral
|
||||
f"{target.positive}", # positive_target
|
||||
f"{target.positive} {neutral}", # positive_target with neutral
|
||||
f"{target.negative}", # negative_target
|
||||
@@ -320,7 +314,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
# move to cpu to save vram
|
||||
# We don't need text encoder anymore, but keep it on cpu for sampling
|
||||
# if text encoder is list
|
||||
@@ -364,7 +357,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
loss_function = torch.nn.MSELoss()
|
||||
|
||||
def get_noise_pred(neg, pos, gs, cts, dn):
|
||||
return self.predict_noise(
|
||||
return self.sd.predict_noise(
|
||||
latents=dn,
|
||||
text_embeddings=train_tools.concat_prompt_embeddings(
|
||||
neg, # negative prompt
|
||||
@@ -391,9 +384,11 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
).item()
|
||||
|
||||
# get noise
|
||||
noise = self.get_latent_noise(
|
||||
noise = self.sd.get_latent_noise(
|
||||
pixel_height=height,
|
||||
pixel_width=width,
|
||||
batch_size=self.train_config.batch_size,
|
||||
noise_offset=self.train_config.noise_offset,
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
|
||||
# get latents
|
||||
@@ -403,7 +398,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
with self.network:
|
||||
assert self.network.is_active
|
||||
self.network.multiplier = multiplier * rand_weight
|
||||
denoised_latents = self.diffuse_some_steps(
|
||||
denoised_latents = self.sd.diffuse_some_steps(
|
||||
latents, # pass simple noise latents
|
||||
train_tools.concat_prompt_embeddings(
|
||||
prompt_pair.positive_target, # unconditional
|
||||
|
||||
Reference in New Issue
Block a user