mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Information trainer
This commit is contained in:
@@ -13,7 +13,8 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
|
||||
sys.path.append(REPOS_ROOT)
|
||||
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
|
||||
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, KDPM2DiscreteScheduler
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, KDPM2DiscreteScheduler, PNDMScheduler, \
|
||||
DDIMScheduler, DDPMScheduler
|
||||
|
||||
from jobs.process import BaseTrainProcess
|
||||
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors
|
||||
@@ -38,8 +39,9 @@ VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
|
||||
|
||||
|
||||
class BaseSDTrainProcess(BaseTrainProcess):
|
||||
def __init__(self, process_id: int, job, config: OrderedDict):
|
||||
def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=None):
|
||||
super().__init__(process_id, job, config)
|
||||
self.custom_pipeline = custom_pipeline
|
||||
self.step_num = 0
|
||||
self.start_step = 0
|
||||
self.device = self.get_conf('device', self.job.device)
|
||||
@@ -271,6 +273,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
)
|
||||
|
||||
self.print(f"Saved to {file_path}")
|
||||
self.clean_up_saves()
|
||||
|
||||
# Called before the model is loaded
|
||||
def hook_before_model_load(self):
|
||||
@@ -467,18 +470,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
|
||||
# TODO handle other schedulers
|
||||
sch = KDPM2DiscreteScheduler
|
||||
# do our own scheduler
|
||||
scheduler = KDPM2DiscreteScheduler(
|
||||
scheduler = sch(
|
||||
num_train_timesteps=1000,
|
||||
beta_start=0.00085,
|
||||
beta_end=0.0120,
|
||||
beta_schedule="scaled_linear",
|
||||
)
|
||||
if self.model_config.is_xl:
|
||||
pipe = CustomStableDiffusionXLPipeline.from_single_file(
|
||||
if self.custom_pipeline is not None:
|
||||
pipln = self.custom_pipeline
|
||||
else:
|
||||
pipln = CustomStableDiffusionXLPipeline
|
||||
pipe = pipln.from_single_file(
|
||||
self.model_config.name_or_path,
|
||||
dtype=dtype,
|
||||
scheduler_type='dpm',
|
||||
scheduler_type='ddpm',
|
||||
device=self.device_torch,
|
||||
).to(self.device_torch)
|
||||
|
||||
@@ -490,7 +499,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
text_encoder.eval()
|
||||
text_encoder = text_encoders
|
||||
else:
|
||||
pipe = CustomStableDiffusionPipeline.from_single_file(
|
||||
if self.custom_pipeline is not None:
|
||||
pipln = self.custom_pipeline
|
||||
else:
|
||||
pipln = CustomStableDiffusionPipeline
|
||||
pipe = pipln.from_single_file(
|
||||
self.model_config.name_or_path,
|
||||
dtype=dtype,
|
||||
scheduler_type='dpm',
|
||||
@@ -614,7 +627,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
if self.has_first_sample_requested:
|
||||
self.print("Generating first sample from first sample config")
|
||||
self.sample(0, is_first=False)
|
||||
self.sample(0, is_first=True)
|
||||
|
||||
# sample first
|
||||
if self.train_config.skip_first_sample:
|
||||
|
||||
@@ -5,6 +5,7 @@ from collections import OrderedDict
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -14,6 +15,7 @@ from toolkit.paths import REPOS_ROOT
|
||||
import sys
|
||||
|
||||
from toolkit.stable_diffusion_model import PromptEmbeds
|
||||
from toolkit.train_pipelines import TransferStableDiffusionXLPipeline
|
||||
|
||||
sys.path.append(REPOS_ROOT)
|
||||
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
|
||||
@@ -61,7 +63,8 @@ class PromptEmbedsCache:
|
||||
|
||||
class TrainSDRescaleProcess(BaseSDTrainProcess):
|
||||
def __init__(self, process_id: int, job, config: OrderedDict):
|
||||
super().__init__(process_id, job, config)
|
||||
# pass our custom pipeline to super so it sets it up
|
||||
super().__init__(process_id, job, config, custom_pipeline=TransferStableDiffusionXLPipeline)
|
||||
self.step_num = 0
|
||||
self.start_step = 0
|
||||
self.device = self.get_conf('device', self.job.device)
|
||||
@@ -173,9 +176,6 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
|
||||
if prompt is None:
|
||||
raise ValueError(f"Prompt {prompt_txt} is not in cache")
|
||||
|
||||
noise_scheduler = self.sd.noise_scheduler
|
||||
optimizer = self.optimizer
|
||||
lr_scheduler = self.lr_scheduler
|
||||
loss_function = torch.nn.MSELoss()
|
||||
|
||||
with torch.no_grad():
|
||||
@@ -189,13 +189,6 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
|
||||
timesteps_to = torch.randint(
|
||||
1, self.train_config.max_denoising_steps, (1,)
|
||||
).item()
|
||||
absolute_total_timesteps = 1000
|
||||
|
||||
max_len_timestep_str = len(str(self.train_config.max_denoising_steps))
|
||||
# pad with spaces
|
||||
timestep_str = str(timesteps_to).rjust(max_len_timestep_str, " ")
|
||||
new_description = f"{self.job.name} ts: {timestep_str}"
|
||||
self.progress_bar.set_description(new_description)
|
||||
|
||||
# get noise
|
||||
latents = self.get_latent_noise(
|
||||
@@ -203,105 +196,71 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
|
||||
pixel_width=self.rescale_config.from_resolution,
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
|
||||
denoised_fraction = timesteps_to / absolute_total_timesteps
|
||||
self.sd.pipeline.to(self.device_torch)
|
||||
torch.set_default_device(self.device_torch)
|
||||
|
||||
# turn off progress bar
|
||||
self.sd.pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
pre_train = False
|
||||
# get random guidance scale from 1.0 to 10.0
|
||||
guidance_scale = torch.rand(1).item() * 9.0 + 1.0
|
||||
|
||||
if not pre_train:
|
||||
# partially denoise the latents
|
||||
denoised_latents = self.sd.pipeline(
|
||||
num_inference_steps=self.train_config.max_denoising_steps,
|
||||
denoising_end=denoised_fraction,
|
||||
latents=latents,
|
||||
prompt_embeds=prompt.text_embeds,
|
||||
negative_prompt_embeds=neutral.text_embeds,
|
||||
pooled_prompt_embeds=prompt.pooled_embeds,
|
||||
negative_pooled_prompt_embeds=neutral.pooled_embeds,
|
||||
output_type="latent",
|
||||
num_images_per_prompt=self.train_config.batch_size,
|
||||
guidance_scale=3,
|
||||
).images.to(self.device_torch, dtype=dtype)
|
||||
current_timestep = timesteps_to
|
||||
loss_arr = []
|
||||
|
||||
else:
|
||||
denoised_latents = latents
|
||||
current_timestep = 1
|
||||
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
1000
|
||||
)
|
||||
max_len_timestep_str = len(str(self.train_config.max_denoising_steps))
|
||||
# pad with spaces
|
||||
timestep_str = str(timesteps_to).rjust(max_len_timestep_str, " ")
|
||||
new_description = f"{self.job.name} ts: {timestep_str}"
|
||||
self.progress_bar.set_description(new_description)
|
||||
|
||||
from_prediction = self.sd.pipeline.predict_noise(
|
||||
latents=denoised_latents,
|
||||
def pre_condition_callback(target_pred, input_latents):
|
||||
# handle any manipulations before feeding to our network
|
||||
reduced_pred = self.reduce_size_fn(target_pred)
|
||||
reduced_latents = self.reduce_size_fn(input_latents)
|
||||
self.optimizer.zero_grad()
|
||||
return reduced_pred, reduced_latents
|
||||
|
||||
def each_step_callback(noise_target, noise_train_pred):
|
||||
noise_target.requires_grad = False
|
||||
loss = loss_function(noise_target, noise_train_pred)
|
||||
loss_arr.append(loss.item())
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
self.lr_scheduler.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# run the pipeline
|
||||
self.sd.pipeline.transfer_diffuse(
|
||||
num_inference_steps=timesteps_to,
|
||||
latents=latents,
|
||||
prompt_embeds=prompt.text_embeds,
|
||||
negative_prompt_embeds=neutral.text_embeds,
|
||||
pooled_prompt_embeds=prompt.pooled_embeds,
|
||||
negative_pooled_prompt_embeds=neutral.pooled_embeds,
|
||||
timestep=current_timestep,
|
||||
guidance_scale=1,
|
||||
output_type="latent",
|
||||
num_images_per_prompt=self.train_config.batch_size,
|
||||
# predict_noise=True,
|
||||
num_inference_steps=1000,
|
||||
guidance_scale=guidance_scale,
|
||||
network=self.network,
|
||||
target_unet=self.sd.unet,
|
||||
pre_condition_callback=pre_condition_callback,
|
||||
each_step_callback=each_step_callback,
|
||||
)
|
||||
|
||||
reduced_from_prediction = self.reduce_size_fn(from_prediction)
|
||||
|
||||
# get noise prediction at reduced scale
|
||||
to_denoised_latents = self.reduce_size_fn(denoised_latents).to(self.device_torch, dtype=dtype)
|
||||
|
||||
# start gradient
|
||||
optimizer.zero_grad()
|
||||
self.network.multiplier = 1.0
|
||||
with self.network:
|
||||
assert self.network.is_active is True
|
||||
to_prediction = self.sd.pipeline.predict_noise(
|
||||
latents=to_denoised_latents,
|
||||
prompt_embeds=prompt.text_embeds,
|
||||
negative_prompt_embeds=neutral.text_embeds,
|
||||
pooled_prompt_embeds=prompt.pooled_embeds,
|
||||
negative_pooled_prompt_embeds=neutral.pooled_embeds,
|
||||
timestep=current_timestep,
|
||||
guidance_scale=1,
|
||||
num_images_per_prompt=self.train_config.batch_size,
|
||||
# predict_noise=True,
|
||||
num_inference_steps=1000,
|
||||
)
|
||||
|
||||
reduced_from_prediction.requires_grad = False
|
||||
from_prediction.requires_grad = False
|
||||
|
||||
loss = loss_function(
|
||||
reduced_from_prediction,
|
||||
to_prediction,
|
||||
)
|
||||
|
||||
loss_float = loss.item()
|
||||
|
||||
loss = loss.to(self.device_torch)
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
del (
|
||||
reduced_from_prediction,
|
||||
from_prediction,
|
||||
to_denoised_latents,
|
||||
to_prediction,
|
||||
latents,
|
||||
)
|
||||
flush()
|
||||
|
||||
# reset network
|
||||
self.network.multiplier = 1.0
|
||||
|
||||
# average losses
|
||||
s = 0
|
||||
for num in loss_arr:
|
||||
s += num
|
||||
|
||||
avg_loss = s / len(loss_arr)
|
||||
|
||||
loss_dict = OrderedDict(
|
||||
{'loss': loss_float},
|
||||
{'loss': avg_loss},
|
||||
)
|
||||
|
||||
return loss_dict
|
||||
|
||||
Reference in New Issue
Block a user