mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-10 15:39:57 +00:00
Various bug fixes and improvements
This commit is contained in:
@@ -2,11 +2,12 @@ import copy
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional, Union, List
|
||||
from torch.utils.data import ConcatDataset, DataLoader
|
||||
from toolkit.data_loader import PairedImageDataset
|
||||
from toolkit.prompt_utils import concat_prompt_embeds
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
import gc
|
||||
from toolkit import train_tools
|
||||
@@ -80,34 +81,16 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
imgs, prompts = batch
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype)
|
||||
|
||||
# split batched images in half so left is negative and right is positive
|
||||
negative_images, positive_images = torch.chunk(imgs, 2, dim=3)
|
||||
|
||||
positive_latents = self.sd.encode_images(positive_images)
|
||||
negative_latents = self.sd.encode_images(negative_images)
|
||||
|
||||
height = positive_images.shape[2]
|
||||
width = positive_images.shape[3]
|
||||
batch_size = positive_images.shape[0]
|
||||
|
||||
# encode the images
|
||||
positive_latents = self.sd.vae.encode(positive_images).latent_dist.sample()
|
||||
positive_latents = positive_latents * 0.18215
|
||||
negative_latents = self.sd.vae.encode(negative_images).latent_dist.sample()
|
||||
negative_latents = negative_latents * 0.18215
|
||||
|
||||
embedding_list = []
|
||||
negative_embedding_list = []
|
||||
# embed the prompts
|
||||
for prompt in prompts:
|
||||
embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
|
||||
embedding_list.append(embedding)
|
||||
# just empty for now
|
||||
# todo cache this?
|
||||
negative_embed = self.sd.encode_prompt('').to(self.device_torch, dtype=dtype)
|
||||
negative_embedding_list.append(negative_embed)
|
||||
|
||||
conditional_embeds = concat_prompt_embeds(embedding_list)
|
||||
unconditional_embeds = concat_prompt_embeds(negative_embedding_list)
|
||||
|
||||
if self.train_config.gradient_checkpointing:
|
||||
# may get disabled elsewhere
|
||||
self.sd.unet.enable_gradient_checkpointing()
|
||||
@@ -115,26 +98,12 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
noise_scheduler = self.sd.noise_scheduler
|
||||
optimizer = self.optimizer
|
||||
lr_scheduler = self.lr_scheduler
|
||||
loss_function = torch.nn.MSELoss()
|
||||
|
||||
def get_noise_pred(neg, pos, gs, cts, dn):
|
||||
return self.sd.predict_noise(
|
||||
latents=dn,
|
||||
text_embeddings=train_tools.concat_prompt_embeddings(
|
||||
neg, # negative prompt
|
||||
pos, # positive prompt
|
||||
self.train_config.batch_size,
|
||||
),
|
||||
timestep=cts,
|
||||
guidance_scale=gs,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
self.train_config.max_denoising_steps, device=self.device_torch
|
||||
)
|
||||
|
||||
timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch)
|
||||
timesteps = torch.randint(0, self.train_config.max_denoising_steps, (1,), device=self.device_torch)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# get noise
|
||||
@@ -147,6 +116,7 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
|
||||
if do_mirror_loss:
|
||||
# mirror the noise
|
||||
# torch shape is [batch, channels, height, width]
|
||||
noise_negative = torch.flip(noise_positive.clone(), dims=[3])
|
||||
else:
|
||||
noise_negative = noise_positive.clone()
|
||||
@@ -159,8 +129,6 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0)
|
||||
noise = torch.cat([noise_positive, noise_negative], dim=0)
|
||||
timesteps = torch.cat([timesteps, timesteps], dim=0)
|
||||
conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
|
||||
unconditional_embeds = concat_prompt_embeds([unconditional_embeds, unconditional_embeds])
|
||||
network_multiplier = [1.0, -1.0]
|
||||
|
||||
flush()
|
||||
@@ -170,22 +138,31 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
loss_mirror_float = None
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
noisy_latents.requires_grad = False
|
||||
|
||||
# if training text encoder enable grads, else do context of no grad
|
||||
with torch.set_grad_enabled(self.train_config.train_text_encoder):
|
||||
# text encoding
|
||||
embedding_list = []
|
||||
# embed the prompts
|
||||
for prompt in prompts:
|
||||
embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
|
||||
embedding_list.append(embedding)
|
||||
conditional_embeds = concat_prompt_embeds(embedding_list)
|
||||
conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
|
||||
|
||||
with self.network:
|
||||
assert self.network.is_active
|
||||
loss_list = []
|
||||
|
||||
# do positive first
|
||||
self.network.multiplier = network_multiplier
|
||||
|
||||
noise_pred = get_noise_pred(
|
||||
unconditional_embeds,
|
||||
conditional_embeds,
|
||||
1,
|
||||
timesteps,
|
||||
noisy_latents
|
||||
noise_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents,
|
||||
conditional_embeddings=conditional_embeds,
|
||||
timestep=timesteps,
|
||||
)
|
||||
|
||||
if self.sd.is_v2: # check is vpred, don't want to track it down right now
|
||||
if self.sd.prediction_type == 'v_prediction':
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
|
||||
else:
|
||||
@@ -199,7 +176,6 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
loss = loss.mean()
|
||||
loss_slide_float = loss.item()
|
||||
|
||||
|
||||
if do_mirror_loss:
|
||||
noise_pred_pos, noise_pred_neg = torch.chunk(noise_pred, 2, dim=0)
|
||||
# mirror the negative
|
||||
@@ -221,7 +197,6 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
|
||||
# reset network
|
||||
self.network.multiplier = 1.0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user