Slider training functioning, time to perfect it

This commit is contained in:
Jaret Burkett
2023-07-21 22:06:49 -06:00
parent ddcd9069e1
commit 596e59dd6d

View File

@@ -3,6 +3,8 @@
import time
from collections import OrderedDict
import os
from typing import List
from toolkit.kohya_model_util import load_vae
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.paths import REPOS_ROOT
@@ -99,6 +101,21 @@ class ModelConfig:
raise ValueError('name_or_path must be specified')
class SliderTargetConfig:
def __init__(self, **kwargs):
self.target_class: str = kwargs.get('target_class', None)
self.positive: str = kwargs.get('positive', None)
self.negative: str = kwargs.get('negative', None)
class SliderConfig:
def __init__(self, **kwargs):
targets = kwargs.get('targets', [])
targets = [SliderTargetConfig(**target) for target in targets]
self.targets: List[SliderTargetConfig] = targets
self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]])
class PromptSettingsOld:
def __init__(self, **kwargs):
self.target: str = kwargs.get('target', None)
@@ -113,6 +130,24 @@ class PromptSettingsOld:
self.dynamic_crops: bool = kwargs.get('dynamic_crops', False) # default is False. only used when model is XL
class EncodedPromptPair:
def __init__(
self,
target_class,
positive,
negative,
neutral,
width=512,
height=512
):
self.target_class = target_class
self.positive = positive
self.negative = negative
self.neutral = neutral
self.width = width
self.height = height
class TrainSliderProcess(BaseTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
@@ -127,10 +162,9 @@ class TrainSliderProcess(BaseTrainProcess):
self.save_config = SaveConfig(**self.get_conf('save', {}))
self.sample_config = SampleConfig(**self.get_conf('sample', {}))
self.logging_config = LogingConfig(**self.get_conf('logging', {}))
self.slider_config = SliderConfig(**self.get_conf('slider', {}))
self.sd = None
self.prompt_settings = self.get_prompt_settings()
# added later
self.network = None
self.scheduler = None
@@ -142,14 +176,6 @@ class TrainSliderProcess(BaseTrainProcess):
param.data = -param.data
self.is_flipped = not self.is_flipped
def get_prompt_settings(self):
prompts = self.get_conf('prompts', required=True)
prompt_settings = [PromptSettingsOld(**prompt) for prompt in prompts]
# for i, prompt in enumerate(prompts):
# prompt_settings[i].fill_prompts(prompt)
return prompt_settings
def sample(self, step=None):
sample_folder = os.path.join(self.save_root, 'samples')
if not os.path.exists(sample_folder):
@@ -352,44 +378,38 @@ class TrainSliderProcess(BaseTrainProcess):
max_iterations=self.train_config.steps,
lr_min=self.train_config.lr / 100, # not sure why leco did this, but ill do it to
)
criteria = torch.nn.MSELoss()
if self.logging_config.verbose:
print("Prompts")
for settings in self.prompt_settings:
print(settings)
# debug
# debug_util.check_requires_grad(network)
# debug_util.check_training_mode(network)
loss_function = torch.nn.MSELoss()
cache = PromptEmbedsCache()
prompt_pairs: list[PromptEmbedsPair] = []
prompt_pairs: list[LatentPair] = []
# get encoded latents for our prompts
with torch.no_grad():
for settings in self.prompt_settings:
self.print(settings)
for prompt in [
settings.target,
settings.positive,
settings.neutral,
settings.unconditional,
]:
if cache[prompt] == None:
cache[prompt] = train_util.encode_prompts(
tokenizer, text_encoder, [prompt]
)
neutral = ""
for target in self.slider_config.targets:
for resolution in self.slider_config.resolutions:
width, height = resolution
for prompt in [
target.target_class,
target.positive,
target.negative,
neutral # empty neutral
]:
if cache[prompt] == None:
cache[prompt] = train_util.encode_prompts(
tokenizer, text_encoder, [prompt]
)
prompt_pairs.append(
PromptEmbedsPair(
criteria,
cache[settings.target],
cache[settings.positive],
cache[settings.unconditional],
cache[settings.neutral],
settings,
prompt_pairs.append(
EncodedPromptPair(
target_class=cache[target.target_class],
positive=cache[target.positive],
negative=cache[target.negative],
neutral=cache[neutral],
width=width,
height=height,
)
)
)
# move to cpu to save vram
# tokenizer.to("cpu")
@@ -400,7 +420,6 @@ class TrainSliderProcess(BaseTrainProcess):
self.print("Generating baseline samples before training")
self.sample(0)
self.progress_bar = tqdm(range(self.train_config.steps))
self.progress_bar = tqdm(
total=self.train_config.steps,
desc=self.job.name,
@@ -408,6 +427,29 @@ class TrainSliderProcess(BaseTrainProcess):
)
self.step_num = 0
for step in range(self.train_config.steps):
# get a random pair
prompt_pair: EncodedPromptPair = prompt_pairs[
torch.randint(0, len(prompt_pairs), (1,)).item()
]
height = prompt_pair.height
width = prompt_pair.width
positive = prompt_pair.positive
target_class = prompt_pair.target_class
neutral = prompt_pair.neutral
negative = prompt_pair.negative
# swap every other step and invert lora to spread slider
do_swap = step % 2 == 0
if do_swap:
negative = prompt_pair.positive
positive = prompt_pair.negative
# set the network in a negative weight
self.network.multiplier = -1.0
with torch.no_grad():
noise_scheduler.set_timesteps(
self.train_config.max_denoising_steps, device=self.device_torch
@@ -415,34 +457,17 @@ class TrainSliderProcess(BaseTrainProcess):
optimizer.zero_grad()
prompt_pair: PromptEmbedsPair = prompt_pairs[
torch.randint(0, len(prompt_pairs), (1,)).item()
]
# 1 ~ 49 random from 1 to 49
# ger a random number of steps
timesteps_to = torch.randint(
1, self.train_config.max_denoising_steps, (1,)
).item()
height, width = (
prompt_pair.resolution,
prompt_pair.resolution,
)
if prompt_pair.dynamic_resolution:
height, width = train_util.get_random_resolution_in_bucket(
prompt_pair.resolution
)
if self.logging_config.verbose:
self.print("guidance_scale:", prompt_pair.guidance_scale)
self.print("resolution:", prompt_pair.resolution)
self.print("dynamic_resolution:", prompt_pair.dynamic_resolution)
if prompt_pair.dynamic_resolution:
self.print("bucketed resolution:", (height, width))
self.print("batch_size:", prompt_pair.batch_size)
latents = train_util.get_initial_latents(
noise_scheduler, prompt_pair.batch_size, height, width, 1
noise_scheduler,
self.train_config.batch_size,
height,
width,
1
).to(self.device_torch, dtype=dtype)
with self.network:
@@ -453,9 +478,9 @@ class TrainSliderProcess(BaseTrainProcess):
noise_scheduler,
latents, # pass simple noise latents
train_util.concat_embeddings(
prompt_pair.unconditional,
prompt_pair.target,
prompt_pair.batch_size,
positive, # unconditional
target_class, # target
self.train_config.batch_size,
),
start_timesteps=0,
total_timesteps=timesteps_to,
@@ -468,16 +493,16 @@ class TrainSliderProcess(BaseTrainProcess):
int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
]
# with network: Only empty LoRA is enabled outside with network :
positive_latents = train_util.predict_noise(
# with network: 0 weight LoRA is enabled outside "with network:"
positive_latents = train_util.predict_noise( # positive_latents
unet,
noise_scheduler,
current_timestep,
denoised_latents,
train_util.concat_embeddings(
prompt_pair.unconditional,
prompt_pair.positive,
prompt_pair.batch_size,
positive, # unconditional
negative, # positive
self.train_config.batch_size,
),
guidance_scale=1,
).to("cpu", dtype=torch.float32)
@@ -487,9 +512,9 @@ class TrainSliderProcess(BaseTrainProcess):
current_timestep,
denoised_latents,
train_util.concat_embeddings(
prompt_pair.unconditional,
prompt_pair.neutral,
prompt_pair.batch_size,
positive, # unconditional
neutral, # neutral
self.train_config.batch_size,
),
guidance_scale=1,
).to("cpu", dtype=torch.float32)
@@ -499,16 +524,12 @@ class TrainSliderProcess(BaseTrainProcess):
current_timestep,
denoised_latents,
train_util.concat_embeddings(
prompt_pair.unconditional,
prompt_pair.unconditional,
prompt_pair.batch_size,
positive, # unconditional
positive, # unconditional
self.train_config.batch_size,
),
guidance_scale=1,
).to("cpu", dtype=torch.float32)
# if self.logging_config.verbose:
# self.print("positive_latents:", positive_latents[0, 0, :5, :5])
# self.print("neutral_latents:", neutral_latents[0, 0, :5, :5])
# self.print("unconditional_latents:", unconditional_latents[0, 0, :5, :5])
with self.network:
target_latents = train_util.predict_noise(
@@ -517,9 +538,9 @@ class TrainSliderProcess(BaseTrainProcess):
current_timestep,
denoised_latents,
train_util.concat_embeddings(
prompt_pair.unconditional,
prompt_pair.target,
prompt_pair.batch_size,
positive, # unconditional
target_class, # target
self.train_config.batch_size,
),
guidance_scale=1,
).to("cpu", dtype=torch.float32)
@@ -531,12 +552,23 @@ class TrainSliderProcess(BaseTrainProcess):
neutral_latents.requires_grad = False
unconditional_latents.requires_grad = False
loss = prompt_pair.loss(
target_latents=target_latents,
positive_latents=positive_latents,
neutral_latents=neutral_latents,
unconditional_latents=unconditional_latents,
erase = True
guidance_scale = 1.0
offset = guidance_scale * (positive_latents - unconditional_latents)
offset_neutral = neutral_latents
if erase:
offset_neutral -= offset
else:
# enhance
offset_neutral += offset
loss = loss_function(
target_latents,
offset_neutral,
)
loss_float = loss.item()
if self.train_config.optimizer.startswith('dadaptation'):
learning_rate = (
@@ -561,6 +593,9 @@ class TrainSliderProcess(BaseTrainProcess):
)
flush()
# reset network
self.network.multiplier = 1.0
# don't do on first step
if self.step_num != self.start_step:
# pause progress bar
@@ -594,8 +629,11 @@ class TrainSliderProcess(BaseTrainProcess):
# end of step
self.step_num = step
print("")
self.save()
del (
unet,
noise_scheduler,