mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-22 21:33:59 +00:00
Tons of bug fixes and improvements to special training. Fixed slider training.
This commit is contained in:
@@ -2,13 +2,15 @@ from collections import OrderedDict
|
||||
from typing import Union, Literal, List
|
||||
from diffusers import T2IAdapter
|
||||
|
||||
import torch.functional as F
|
||||
from toolkit import train_tools
|
||||
from toolkit.basic import value_map, adain, get_mean_std
|
||||
from toolkit.config_modules import GuidanceConfig
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
|
||||
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss
|
||||
from toolkit.image_utils import show_tensors, show_latents
|
||||
from toolkit.ip_adapter import IPAdapter
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
|
||||
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
|
||||
from toolkit.train_tools import get_torch_dtype, apply_snr_weight, add_all_snr_to_noise_scheduler, \
|
||||
apply_learnable_snr_gos, LearnableSNRGamma
|
||||
@@ -35,6 +37,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.assistant_adapter: Union['T2IAdapter', None]
|
||||
self.do_prior_prediction = False
|
||||
self.do_long_prompts = False
|
||||
self.do_guided_loss = False
|
||||
if self.train_config.inverted_mask_prior:
|
||||
self.do_prior_prediction = True
|
||||
|
||||
@@ -186,6 +189,33 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
batch: 'DataLoaderBatchDTO',
|
||||
noise: torch.Tensor,
|
||||
**kwargs
|
||||
):
|
||||
loss = get_guidance_loss(
|
||||
noisy_latents=noisy_latents,
|
||||
conditional_embeds=conditional_embeds,
|
||||
match_adapter_assist=match_adapter_assist,
|
||||
network_weight_list=network_weight_list,
|
||||
timesteps=timesteps,
|
||||
pred_kwargs=pred_kwargs,
|
||||
batch=batch,
|
||||
noise=noise,
|
||||
sd=self.sd,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
def get_guided_loss_targeted_polarity(
|
||||
self,
|
||||
noisy_latents: torch.Tensor,
|
||||
conditional_embeds: PromptEmbeds,
|
||||
match_adapter_assist: bool,
|
||||
network_weight_list: list,
|
||||
timesteps: torch.Tensor,
|
||||
pred_kwargs: dict,
|
||||
batch: 'DataLoaderBatchDTO',
|
||||
noise: torch.Tensor,
|
||||
**kwargs
|
||||
):
|
||||
with torch.no_grad():
|
||||
# Perform targeted guidance (working title)
|
||||
@@ -194,23 +224,28 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach()
|
||||
unconditional_latents = batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach()
|
||||
|
||||
unconditional_diff = unconditional_latents - conditional_latents
|
||||
conditional_diff = conditional_latents - unconditional_latents
|
||||
mean_latents = (conditional_latents + unconditional_latents) / 2.0
|
||||
|
||||
unconditional_diff = (unconditional_latents - mean_latents)
|
||||
conditional_diff = (conditional_latents - mean_latents)
|
||||
|
||||
# we need to determine the amount of signal and noise that would be present at the current timestep
|
||||
conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps)
|
||||
unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps)
|
||||
# conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps)
|
||||
# unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps)
|
||||
# unconditional_signal = self.sd.add_noise(unconditional_diff, torch.zeros_like(noise), timesteps)
|
||||
# conditional_blend = self.sd.add_noise(conditional_latents, unconditional_latents, timesteps)
|
||||
# unconditional_blend = self.sd.add_noise(unconditional_latents, conditional_latents, timesteps)
|
||||
|
||||
target_noise = noise + unconditional_signal
|
||||
# target_noise = noise + unconditional_signal
|
||||
|
||||
conditional_noisy_latents = self.sd.add_noise(
|
||||
unconditional_latents + conditional_signal,
|
||||
target_noise,
|
||||
mean_latents,
|
||||
noise,
|
||||
timesteps
|
||||
).detach()
|
||||
|
||||
unconditional_noisy_latents = self.sd.add_noise(
|
||||
unconditional_latents,
|
||||
mean_latents,
|
||||
noise,
|
||||
timesteps
|
||||
).detach()
|
||||
@@ -229,31 +264,210 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
**pred_kwargs # adapter residuals in here
|
||||
).detach()
|
||||
|
||||
# turn the LoRA network back on.
|
||||
self.sd.unet.train()
|
||||
self.network.is_active = True
|
||||
self.network.multiplier = network_weight_list
|
||||
# double up everything to run it through all at once
|
||||
cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
|
||||
cat_latents = torch.cat([conditional_noisy_latents, conditional_noisy_latents], dim=0)
|
||||
cat_timesteps = torch.cat([timesteps, timesteps], dim=0)
|
||||
|
||||
# since we are dividing the polarity from the middle out, we need to double our network
|
||||
# weights on training since the convergent point will be at half network strength
|
||||
|
||||
negative_network_weights = [weight * -2.0 for weight in network_weight_list]
|
||||
positive_network_weights = [weight * 2.0 for weight in network_weight_list]
|
||||
cat_network_weight_list = positive_network_weights + negative_network_weights
|
||||
|
||||
# turn the LoRA network back on.
|
||||
self.sd.unet.train()
|
||||
self.network.is_active = True
|
||||
|
||||
self.network.multiplier = cat_network_weight_list
|
||||
|
||||
# do our prediction with LoRA active on the scaled guidance latents
|
||||
prediction = self.sd.predict_noise(
|
||||
latents=conditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||
timestep=timesteps,
|
||||
latents=cat_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
conditional_embeddings=cat_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||
timestep=cat_timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
)
|
||||
|
||||
# remove baseline from our prediction to extract our differential prediction
|
||||
prediction = prediction - baseline_prediction
|
||||
pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)
|
||||
|
||||
loss = torch.nn.functional.mse_loss(
|
||||
prediction.float(),
|
||||
unconditional_signal.float(),
|
||||
pred_pos = pred_pos - baseline_prediction
|
||||
pred_neg = pred_neg - baseline_prediction
|
||||
|
||||
pred_loss = torch.nn.functional.mse_loss(
|
||||
pred_pos.float(),
|
||||
unconditional_diff.float(),
|
||||
reduction="none"
|
||||
)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
pred_loss = pred_loss.mean([1, 2, 3])
|
||||
|
||||
loss = self.apply_snr(loss, timesteps)
|
||||
pred_neg_loss = torch.nn.functional.mse_loss(
|
||||
pred_neg.float(),
|
||||
conditional_diff.float(),
|
||||
reduction="none"
|
||||
)
|
||||
pred_neg_loss = pred_neg_loss.mean([1, 2, 3])
|
||||
|
||||
loss = (pred_loss + pred_neg_loss) / 2.0
|
||||
|
||||
# loss = self.apply_snr(loss, timesteps)
|
||||
loss = loss.mean()
|
||||
loss.backward()
|
||||
|
||||
# detach it so parent class can run backward on no grads without throwing error
|
||||
loss = loss.detach()
|
||||
loss.requires_grad_(True)
|
||||
|
||||
return loss
|
||||
|
||||
def get_guided_loss_masked_polarity(
|
||||
self,
|
||||
noisy_latents: torch.Tensor,
|
||||
conditional_embeds: PromptEmbeds,
|
||||
match_adapter_assist: bool,
|
||||
network_weight_list: list,
|
||||
timesteps: torch.Tensor,
|
||||
pred_kwargs: dict,
|
||||
batch: 'DataLoaderBatchDTO',
|
||||
noise: torch.Tensor,
|
||||
**kwargs
|
||||
):
|
||||
with torch.no_grad():
|
||||
# Perform targeted guidance (working title)
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
|
||||
conditional_latents = batch.latents.to(self.device_torch, dtype=dtype).detach()
|
||||
unconditional_latents = batch.unconditional_latents.to(self.device_torch, dtype=dtype).detach()
|
||||
inverse_latents = unconditional_latents - (conditional_latents - unconditional_latents)
|
||||
|
||||
mean_latents = (conditional_latents + unconditional_latents) / 2.0
|
||||
|
||||
# unconditional_diff = (unconditional_latents - mean_latents)
|
||||
# conditional_diff = (conditional_latents - mean_latents)
|
||||
|
||||
# we need to determine the amount of signal and noise that would be present at the current timestep
|
||||
# conditional_signal = self.sd.add_noise(conditional_diff, torch.zeros_like(noise), timesteps)
|
||||
# unconditional_signal = self.sd.add_noise(torch.zeros_like(noise), unconditional_diff, timesteps)
|
||||
# unconditional_signal = self.sd.add_noise(unconditional_diff, torch.zeros_like(noise), timesteps)
|
||||
# conditional_blend = self.sd.add_noise(conditional_latents, unconditional_latents, timesteps)
|
||||
# unconditional_blend = self.sd.add_noise(unconditional_latents, conditional_latents, timesteps)
|
||||
|
||||
# make a differential mask
|
||||
differential_mask = torch.abs(conditional_latents - unconditional_latents)
|
||||
max_differential = \
|
||||
differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
|
||||
differential_scaler = 1.0 / max_differential
|
||||
differential_mask = differential_mask * differential_scaler
|
||||
spread_point = 0.1
|
||||
# adjust mask to amplify the differential at 0.1
|
||||
differential_mask = ((differential_mask - spread_point) * 10.0) + spread_point
|
||||
# clip it
|
||||
differential_mask = torch.clamp(differential_mask, 0.0, 1.0)
|
||||
|
||||
# target_noise = noise + unconditional_signal
|
||||
|
||||
conditional_noisy_latents = self.sd.add_noise(
|
||||
conditional_latents,
|
||||
noise,
|
||||
timesteps
|
||||
).detach()
|
||||
|
||||
unconditional_noisy_latents = self.sd.add_noise(
|
||||
unconditional_latents,
|
||||
noise,
|
||||
timesteps
|
||||
).detach()
|
||||
|
||||
inverse_noisy_latents = self.sd.add_noise(
|
||||
inverse_latents,
|
||||
noise,
|
||||
timesteps
|
||||
).detach()
|
||||
|
||||
# Disable the LoRA network so we can predict parent network knowledge without it
|
||||
self.network.is_active = False
|
||||
self.sd.unet.eval()
|
||||
|
||||
# Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
|
||||
# This acts as our control to preserve the unaltered parts of the image.
|
||||
# baseline_prediction = self.sd.predict_noise(
|
||||
# latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
# conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||
# timestep=timesteps,
|
||||
# guidance_scale=1.0,
|
||||
# **pred_kwargs # adapter residuals in here
|
||||
# ).detach()
|
||||
|
||||
# double up everything to run it through all at once
|
||||
cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
|
||||
cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
|
||||
cat_timesteps = torch.cat([timesteps, timesteps], dim=0)
|
||||
|
||||
# since we are dividing the polarity from the middle out, we need to double our network
|
||||
# weights on training since the convergent point will be at half network strength
|
||||
|
||||
negative_network_weights = [weight * -1.0 for weight in network_weight_list]
|
||||
positive_network_weights = [weight * 1.0 for weight in network_weight_list]
|
||||
cat_network_weight_list = positive_network_weights + negative_network_weights
|
||||
|
||||
# turn the LoRA network back on.
|
||||
self.sd.unet.train()
|
||||
self.network.is_active = True
|
||||
|
||||
self.network.multiplier = cat_network_weight_list
|
||||
|
||||
# do our prediction with LoRA active on the scaled guidance latents
|
||||
prediction = self.sd.predict_noise(
|
||||
latents=cat_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
conditional_embeddings=cat_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||
timestep=cat_timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
)
|
||||
|
||||
pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)
|
||||
|
||||
# create a loss to balance the mean to 0 between the two predictions
|
||||
differential_mean_pred_loss = torch.abs(pred_pos - pred_neg).mean([1, 2, 3]) ** 2.0
|
||||
|
||||
# pred_pos = pred_pos - baseline_prediction
|
||||
# pred_neg = pred_neg - baseline_prediction
|
||||
|
||||
pred_loss = torch.nn.functional.mse_loss(
|
||||
pred_pos.float(),
|
||||
noise.float(),
|
||||
reduction="none"
|
||||
)
|
||||
# apply mask
|
||||
pred_loss = pred_loss * (1.0 + differential_mask)
|
||||
pred_loss = pred_loss.mean([1, 2, 3])
|
||||
|
||||
pred_neg_loss = torch.nn.functional.mse_loss(
|
||||
pred_neg.float(),
|
||||
noise.float(),
|
||||
reduction="none"
|
||||
)
|
||||
# apply inverse mask
|
||||
pred_neg_loss = pred_neg_loss * (1.0 - differential_mask)
|
||||
pred_neg_loss = pred_neg_loss.mean([1, 2, 3])
|
||||
|
||||
# make a loss to balance to losses of the pos and neg so they are equal
|
||||
# differential_mean_loss_loss = torch.abs(pred_loss - pred_neg_loss)
|
||||
#
|
||||
# differential_mean_loss = differential_mean_pred_loss + differential_mean_loss_loss
|
||||
#
|
||||
# # add a multiplier to balancing losses to make them the top priority
|
||||
# differential_mean_loss = differential_mean_loss
|
||||
|
||||
# remove the grads from the negative as it is only a balancing loss
|
||||
# pred_neg_loss = pred_neg_loss.detach()
|
||||
|
||||
# loss = pred_loss + pred_neg_loss + differential_mean_loss
|
||||
loss = pred_loss + pred_neg_loss
|
||||
|
||||
# loss = self.apply_snr(loss, timesteps)
|
||||
loss = loss.mean()
|
||||
loss.backward()
|
||||
|
||||
@@ -556,7 +770,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
self.before_unet_predict()
|
||||
# do a prior pred if we have an unconditional image, we will swap out the giadance later
|
||||
if batch.unconditional_latents is not None:
|
||||
if batch.unconditional_latents is not None or self.do_guided_loss:
|
||||
# do guided loss
|
||||
loss = self.get_guided_loss(
|
||||
noisy_latents=noisy_latents,
|
||||
|
||||
@@ -293,6 +293,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# will end in safetensors or pt
|
||||
embed_files = [f for f in embed_items if f.endswith('.safetensors') or f.endswith('.pt')]
|
||||
|
||||
# check for critic files
|
||||
critic_pattern = f"CRITIC_{self.job.name}_*"
|
||||
critic_items = glob.glob(os.path.join(self.save_root, critic_pattern))
|
||||
|
||||
# Sort the lists by creation time if they are not empty
|
||||
if safetensors_files:
|
||||
safetensors_files.sort(key=os.path.getctime)
|
||||
@@ -302,6 +306,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
directories.sort(key=os.path.getctime)
|
||||
if embed_files:
|
||||
embed_files.sort(key=os.path.getctime)
|
||||
if critic_items:
|
||||
critic_items.sort(key=os.path.getctime)
|
||||
|
||||
# Combine and sort the lists
|
||||
combined_items = safetensors_files + directories + pt_files
|
||||
@@ -313,8 +319,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
pt_files_to_remove = pt_files[:-self.save_config.max_step_saves_to_keep] if pt_files else []
|
||||
directories_to_remove = directories[:-self.save_config.max_step_saves_to_keep] if directories else []
|
||||
embeddings_to_remove = embed_files[:-self.save_config.max_step_saves_to_keep] if embed_files else []
|
||||
critic_to_remove = critic_items[:-self.save_config.max_step_saves_to_keep] if critic_items else []
|
||||
|
||||
items_to_remove = safetensors_to_remove + pt_files_to_remove + directories_to_remove + embeddings_to_remove
|
||||
items_to_remove = safetensors_to_remove + pt_files_to_remove + directories_to_remove + embeddings_to_remove + critic_to_remove
|
||||
|
||||
# remove all but the latest max_step_saves_to_keep
|
||||
# items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]
|
||||
@@ -1041,8 +1048,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
train_text_encoder=self.train_config.train_text_encoder,
|
||||
conv_lora_dim=self.network_config.conv,
|
||||
conv_alpha=self.network_config.conv_alpha,
|
||||
is_sdxl=self.model_config.is_xl,
|
||||
is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
|
||||
is_v2=self.model_config.is_v2,
|
||||
is_ssd=self.model_config.is_ssd,
|
||||
dropout=self.network_config.dropout,
|
||||
use_text_encoder_1=self.model_config.use_text_encoder_1,
|
||||
use_text_encoder_2=self.model_config.use_text_encoder_2,
|
||||
|
||||
@@ -371,7 +371,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
|
||||
# ger a random number of steps
|
||||
timesteps_to = torch.randint(
|
||||
1, self.train_config.max_denoising_steps, (1,)
|
||||
1, self.train_config.max_denoising_steps - 1, (1,)
|
||||
).item()
|
||||
|
||||
# get noise
|
||||
@@ -389,7 +389,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
assert not self.network.is_active
|
||||
self.sd.unet.eval()
|
||||
# pass the multiplier list to the network
|
||||
self.network.multiplier = prompt_pair.multiplier_list
|
||||
# double up since we are doing cfg
|
||||
self.network.multiplier = prompt_pair.multiplier_list + prompt_pair.multiplier_list
|
||||
denoised_latents = self.sd.diffuse_some_steps(
|
||||
latents, # pass simple noise latents
|
||||
train_tools.concat_prompt_embeddings(
|
||||
@@ -507,7 +508,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
for anchor_chunk, denoised_latent_chunk, anchor_target_noise_chunk in zip(
|
||||
anchor_chunks, denoised_latent_chunks, anchor_target_noise_chunks
|
||||
):
|
||||
self.network.multiplier = anchor_chunk.multiplier_list
|
||||
self.network.multiplier = anchor_chunk.multiplier_list + anchor_chunk.multiplier_list
|
||||
|
||||
anchor_pred_noise = get_noise_pred(
|
||||
anchor_chunk.neg_prompt, anchor_chunk.prompt, 1, current_timestep, denoised_latent_chunk
|
||||
@@ -582,7 +583,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
mask_multiplier_chunks,
|
||||
unmasked_target_chunks
|
||||
):
|
||||
self.network.multiplier = prompt_pair_chunk.multiplier_list
|
||||
self.network.multiplier = prompt_pair_chunk.multiplier_list + prompt_pair_chunk.multiplier_list
|
||||
target_latents = get_noise_pred(
|
||||
prompt_pair_chunk.positive_target,
|
||||
prompt_pair_chunk.target_class,
|
||||
@@ -611,6 +612,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
offset_neutral = neutral_latents_chunk
|
||||
# offsets are already adjusted on a per-batch basis
|
||||
offset_neutral += offset
|
||||
offset_neutral = offset_neutral.detach().requires_grad_(False)
|
||||
|
||||
# 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing
|
||||
loss = torch.nn.functional.mse_loss(target_latents.float(), offset_neutral.float(), reduction="none")
|
||||
|
||||
20
scripts/generate_sampler_step_scales.py
Normal file
20
scripts/generate_sampler_step_scales.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import argparse
|
||||
import torch
|
||||
import os
|
||||
from diffusers import StableDiffusionPipeline
|
||||
import sys
|
||||
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
# add project root to path
|
||||
sys.path.append(PROJECT_ROOT)
|
||||
|
||||
SAMPLER_SCALES_ROOT = os.path.join(PROJECT_ROOT, 'toolkit', 'samplers_scales')
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='Process some images.')
|
||||
add_arg = parser.add_argument
|
||||
add_arg('--model', type=str, required=True, help='Path to model')
|
||||
add_arg('--sampler', type=str, required=True, help='Name of sampler')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import time
|
||||
from typing import List, Optional, Literal, Union
|
||||
from typing import List, Optional, Literal, Union, TYPE_CHECKING
|
||||
import random
|
||||
|
||||
import torch
|
||||
@@ -11,6 +11,8 @@ ImgExt = Literal['jpg', 'png', 'webp']
|
||||
|
||||
SaveFormat = Literal['safetensors', 'diffusers']
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.guidance import GuidanceType
|
||||
|
||||
class SaveConfig:
|
||||
def __init__(self, **kwargs):
|
||||
@@ -400,6 +402,7 @@ class DatasetConfig:
|
||||
if legacy_caption_type:
|
||||
self.caption_ext = legacy_caption_type
|
||||
self.caption_type = self.caption_ext
|
||||
self.guidance_type: GuidanceType = kwargs.get('guidance_type', 'targeted_polarity')
|
||||
|
||||
|
||||
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import torch
|
||||
from typing import Literal
|
||||
|
||||
from toolkit.basic import value_map
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
@@ -8,13 +10,16 @@ from toolkit.train_tools import get_torch_dtype
|
||||
GuidanceType = Literal["targeted", "polarity", "targeted_polarity"]
|
||||
|
||||
DIFFERENTIAL_SCALER = 0.2
|
||||
|
||||
|
||||
# DIFFERENTIAL_SCALER = 0.25
|
||||
|
||||
|
||||
def get_differential_mask(
|
||||
conditional_latents: torch.Tensor,
|
||||
unconditional_latents: torch.Tensor,
|
||||
threshold: float = 0.2
|
||||
threshold: float = 0.2,
|
||||
gradient: bool = False,
|
||||
):
|
||||
# make a differential mask
|
||||
differential_mask = torch.abs(conditional_latents - unconditional_latents)
|
||||
@@ -23,12 +28,27 @@ def get_differential_mask(
|
||||
differential_scaler = 1.0 / max_differential
|
||||
differential_mask = differential_mask * differential_scaler
|
||||
|
||||
# make everything less than 0.2 be 0.0 and everything else be 1.0
|
||||
differential_mask = torch.where(
|
||||
differential_mask < threshold,
|
||||
torch.zeros_like(differential_mask),
|
||||
torch.ones_like(differential_mask)
|
||||
)
|
||||
if gradient:
|
||||
# wew need to scale it to 0-1
|
||||
# differential_mask = differential_mask - differential_mask.min()
|
||||
# differential_mask = differential_mask / differential_mask.max()
|
||||
# add 0.2 threshold to both sides and clip
|
||||
differential_mask = value_map(
|
||||
differential_mask,
|
||||
differential_mask.min(),
|
||||
differential_mask.max(),
|
||||
0 - threshold,
|
||||
1 + threshold
|
||||
)
|
||||
differential_mask = torch.clamp(differential_mask, 0.0, 1.0)
|
||||
else:
|
||||
|
||||
# make everything less than 0.2 be 0.0 and everything else be 1.0
|
||||
differential_mask = torch.where(
|
||||
differential_mask < threshold,
|
||||
torch.zeros_like(differential_mask),
|
||||
torch.ones_like(differential_mask)
|
||||
)
|
||||
return differential_mask
|
||||
|
||||
|
||||
@@ -47,7 +67,6 @@ def get_targeted_polarity_loss(
|
||||
dtype = get_torch_dtype(sd.torch_dtype)
|
||||
device = sd.device_torch
|
||||
with torch.no_grad():
|
||||
|
||||
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
|
||||
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
|
||||
|
||||
@@ -164,7 +183,7 @@ def get_targeted_polarity_loss(
|
||||
|
||||
return loss
|
||||
|
||||
# This targets only the positive differential
|
||||
|
||||
# targeted
|
||||
def get_targeted_guidance_loss(
|
||||
noisy_latents: torch.Tensor,
|
||||
@@ -183,35 +202,71 @@ def get_targeted_guidance_loss(
|
||||
dtype = get_torch_dtype(sd.torch_dtype)
|
||||
device = sd.device_torch
|
||||
|
||||
|
||||
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
|
||||
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
|
||||
|
||||
unconditional_diff = (unconditional_latents - conditional_latents)
|
||||
# apply random offset to both latents
|
||||
offset = torch.randn((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
||||
offset = offset * 0.1
|
||||
conditional_latents = conditional_latents + offset
|
||||
unconditional_latents = unconditional_latents + offset
|
||||
|
||||
# get random scale 0f 0.8 to 1.2
|
||||
scale = torch.rand((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
||||
scale = scale * 0.4
|
||||
scale = scale + 0.8
|
||||
conditional_latents = conditional_latents * scale
|
||||
unconditional_latents = unconditional_latents * scale
|
||||
|
||||
diff_mask = get_differential_mask(
|
||||
conditional_latents,
|
||||
unconditional_latents,
|
||||
threshold=0.1
|
||||
threshold=0.2,
|
||||
gradient=True
|
||||
)
|
||||
|
||||
# this is a magic number I spent weeks deducing. It works and I have no idea why.
|
||||
# unconditional_diff_noise = unconditional_diff * DIFFERENTIAL_SCALER
|
||||
# standardize inpute to std of 1
|
||||
# combo_std = torch.cat([conditional_latents, unconditional_latents], dim=1).std(dim=[1, 2, 3], keepdim=True)
|
||||
#
|
||||
# # scale the latents to std of 1
|
||||
# conditional_latents = conditional_latents / combo_std
|
||||
# unconditional_latents = unconditional_latents / combo_std
|
||||
|
||||
inputs_abs_mean = torch.abs(conditional_latents).mean(dim=[1, 2, 3], keepdim=True)
|
||||
noise_abs_mean = torch.abs(noise).mean(dim=[1, 2, 3], keepdim=True)
|
||||
diff_noise_scaler = noise_abs_mean / inputs_abs_mean
|
||||
unconditional_diff_noise = unconditional_diff * diff_noise_scaler
|
||||
unconditional_diff = (unconditional_latents - conditional_latents)
|
||||
|
||||
|
||||
|
||||
# get a -0.5 to 0.5 multiplier for the diff noise
|
||||
# noise_multiplier = torch.rand((unconditional_diff.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
||||
# noise_multiplier = noise_multiplier - 0.5
|
||||
noise_multiplier = 1.0
|
||||
|
||||
# unconditional_diff_noise = unconditional_diff * noise_multiplier
|
||||
unconditional_diff_noise = unconditional_diff * noise_multiplier
|
||||
|
||||
# scale it to the timestep
|
||||
unconditional_diff_noise = sd.add_noise(
|
||||
torch.zeros_like(unconditional_latents),
|
||||
unconditional_diff_noise,
|
||||
timesteps
|
||||
)
|
||||
|
||||
unconditional_diff_noise = unconditional_diff_noise * 0.2
|
||||
|
||||
unconditional_diff_noise = unconditional_diff_noise.detach().requires_grad_(False)
|
||||
|
||||
baseline_noisy_latents = sd.add_noise(
|
||||
conditional_latents,
|
||||
unconditional_latents,
|
||||
noise,
|
||||
timesteps
|
||||
).detach()
|
||||
|
||||
target_noise = noise + unconditional_diff_noise
|
||||
noisy_latents = sd.add_noise(
|
||||
conditional_latents,
|
||||
# noise + unconditional_diff_noise,
|
||||
noise,
|
||||
target_noise,
|
||||
# noise,
|
||||
timesteps
|
||||
).detach()
|
||||
# Disable the LoRA network so we can predict parent network knowledge without it
|
||||
@@ -226,15 +281,20 @@ def get_targeted_guidance_loss(
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
).detach()
|
||||
).detach().requires_grad_(False)
|
||||
|
||||
# determine the error for the baseline prediction
|
||||
baseline_prediction_error = baseline_prediction - noise
|
||||
|
||||
|
||||
# turn the LoRA network back on.
|
||||
sd.unet.train()
|
||||
sd.network.is_active = True
|
||||
|
||||
sd.network.multiplier = network_weight_list
|
||||
|
||||
unmasked_baseline_prediction = baseline_prediction * (1.0 - diff_mask)
|
||||
masked_noise = noise * diff_mask
|
||||
# unmasked_baseline_prediction = baseline_prediction * (1.0 - diff_mask)
|
||||
# masked_noise = noise * diff_mask
|
||||
# pred_target = unmasked_noise + unconditional_diff_noise
|
||||
|
||||
# do our prediction with LoRA active on the scaled guidance latents
|
||||
@@ -246,30 +306,32 @@ def get_targeted_guidance_loss(
|
||||
**pred_kwargs # adapter residuals in here
|
||||
)
|
||||
|
||||
prediction = prediction - unmasked_baseline_prediction
|
||||
# prediction = prediction - baseline_prediction
|
||||
|
||||
baseline_loss = torch.nn.functional.mse_loss(
|
||||
baseline_prediction.float(),
|
||||
noise.float(),
|
||||
|
||||
baselined_prediction = prediction - baseline_prediction
|
||||
|
||||
guidance_loss = torch.nn.functional.mse_loss(
|
||||
baselined_prediction.float(),
|
||||
# unconditional_diff_noise.float(),
|
||||
unconditional_diff_noise.float(),
|
||||
reduction="none"
|
||||
)
|
||||
baseline_loss = baseline_loss * (1.0 - diff_mask)
|
||||
baseline_loss = baseline_loss.mean([1, 2, 3])
|
||||
guidance_loss = guidance_loss.mean([1, 2, 3])
|
||||
|
||||
# loss = torch.nn.functional.l1_loss(
|
||||
loss = torch.nn.functional.mse_loss(
|
||||
guidance_loss = guidance_loss.mean()
|
||||
|
||||
|
||||
# do the masked noise prediction
|
||||
masked_noise_loss = torch.nn.functional.mse_loss(
|
||||
prediction.float(),
|
||||
masked_noise.float(),
|
||||
target_noise.float(),
|
||||
reduction="none"
|
||||
)
|
||||
loss = loss * diff_mask
|
||||
loss = loss.mean([1, 2, 3])
|
||||
primary_loss_scaler = 1.0
|
||||
) * diff_mask
|
||||
masked_noise_loss = masked_noise_loss.mean([1, 2, 3])
|
||||
masked_noise_loss = masked_noise_loss.mean()
|
||||
|
||||
loss = (loss * primary_loss_scaler) + baseline_loss
|
||||
|
||||
loss = loss.mean()
|
||||
loss = guidance_loss + masked_noise_loss
|
||||
|
||||
loss.backward()
|
||||
|
||||
@@ -280,8 +342,158 @@ def get_targeted_guidance_loss(
|
||||
return loss
|
||||
|
||||
|
||||
def get_targeted_guidance_loss_WIP(
|
||||
noisy_latents: torch.Tensor,
|
||||
conditional_embeds: 'PromptEmbeds',
|
||||
match_adapter_assist: bool,
|
||||
network_weight_list: list,
|
||||
timesteps: torch.Tensor,
|
||||
pred_kwargs: dict,
|
||||
batch: 'DataLoaderBatchDTO',
|
||||
noise: torch.Tensor,
|
||||
sd: 'StableDiffusion',
|
||||
**kwargs
|
||||
):
|
||||
with torch.no_grad():
|
||||
# Perform targeted guidance (working title)
|
||||
dtype = get_torch_dtype(sd.torch_dtype)
|
||||
device = sd.device_torch
|
||||
|
||||
|
||||
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
|
||||
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
|
||||
|
||||
# apply random offset to both latents
|
||||
offset = torch.randn((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
||||
offset = offset * 0.1
|
||||
conditional_latents = conditional_latents + offset
|
||||
unconditional_latents = unconditional_latents + offset
|
||||
|
||||
# get random scale 0f 0.8 to 1.2
|
||||
scale = torch.rand((conditional_latents.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
||||
scale = scale * 0.4
|
||||
scale = scale + 0.8
|
||||
conditional_latents = conditional_latents * scale
|
||||
unconditional_latents = unconditional_latents * scale
|
||||
|
||||
diff_mask = get_differential_mask(
|
||||
conditional_latents,
|
||||
unconditional_latents,
|
||||
threshold=0.2,
|
||||
gradient=True
|
||||
)
|
||||
|
||||
# standardize inpute to std of 1
|
||||
# combo_std = torch.cat([conditional_latents, unconditional_latents], dim=1).std(dim=[1, 2, 3], keepdim=True)
|
||||
#
|
||||
# # scale the latents to std of 1
|
||||
# conditional_latents = conditional_latents / combo_std
|
||||
# unconditional_latents = unconditional_latents / combo_std
|
||||
|
||||
unconditional_diff = (unconditional_latents - conditional_latents)
|
||||
|
||||
|
||||
|
||||
# get a -0.5 to 0.5 multiplier for the diff noise
|
||||
# noise_multiplier = torch.rand((unconditional_diff.shape[0], 1, 1, 1), device=device, dtype=dtype)
|
||||
# noise_multiplier = noise_multiplier - 0.5
|
||||
noise_multiplier = 1.0
|
||||
|
||||
# unconditional_diff_noise = unconditional_diff * noise_multiplier
|
||||
unconditional_diff_noise = unconditional_diff * noise_multiplier
|
||||
|
||||
# scale it to the timestep
|
||||
unconditional_diff_noise = sd.add_noise(
|
||||
torch.zeros_like(unconditional_latents),
|
||||
unconditional_diff_noise,
|
||||
timesteps
|
||||
)
|
||||
|
||||
unconditional_diff_noise = unconditional_diff_noise * 0.2
|
||||
|
||||
unconditional_diff_noise = unconditional_diff_noise.detach().requires_grad_(False)
|
||||
|
||||
baseline_noisy_latents = sd.add_noise(
|
||||
unconditional_latents,
|
||||
noise,
|
||||
timesteps
|
||||
).detach()
|
||||
|
||||
target_noise = noise + unconditional_diff_noise
|
||||
noisy_latents = sd.add_noise(
|
||||
conditional_latents,
|
||||
target_noise,
|
||||
# noise,
|
||||
timesteps
|
||||
).detach()
|
||||
# Disable the LoRA network so we can predict parent network knowledge without it
|
||||
sd.network.is_active = False
|
||||
sd.unet.eval()
|
||||
|
||||
# Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
|
||||
# This acts as our control to preserve the unaltered parts of the image.
|
||||
baseline_prediction = sd.predict_noise(
|
||||
latents=baseline_noisy_latents.to(device, dtype=dtype).detach(),
|
||||
conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
).detach().requires_grad_(False)
|
||||
|
||||
# turn the LoRA network back on.
|
||||
sd.unet.train()
|
||||
sd.network.is_active = True
|
||||
|
||||
sd.network.multiplier = network_weight_list
|
||||
|
||||
# unmasked_baseline_prediction = baseline_prediction * (1.0 - diff_mask)
|
||||
# masked_noise = noise * diff_mask
|
||||
# pred_target = unmasked_noise + unconditional_diff_noise
|
||||
|
||||
# do our prediction with LoRA active on the scaled guidance latents
|
||||
prediction = sd.predict_noise(
|
||||
latents=noisy_latents.to(device, dtype=dtype).detach(),
|
||||
conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
)
|
||||
|
||||
|
||||
|
||||
baselined_prediction = prediction - baseline_prediction
|
||||
|
||||
guidance_loss = torch.nn.functional.mse_loss(
|
||||
baselined_prediction.float(),
|
||||
# unconditional_diff_noise.float(),
|
||||
unconditional_diff_noise.float(),
|
||||
reduction="none"
|
||||
)
|
||||
guidance_loss = guidance_loss.mean([1, 2, 3])
|
||||
|
||||
guidance_loss = guidance_loss.mean()
|
||||
|
||||
|
||||
# do the masked noise prediction
|
||||
masked_noise_loss = torch.nn.functional.mse_loss(
|
||||
prediction.float(),
|
||||
target_noise.float(),
|
||||
reduction="none"
|
||||
) * diff_mask
|
||||
masked_noise_loss = masked_noise_loss.mean([1, 2, 3])
|
||||
masked_noise_loss = masked_noise_loss.mean()
|
||||
|
||||
|
||||
loss = guidance_loss + masked_noise_loss
|
||||
|
||||
loss.backward()
|
||||
|
||||
# detach it so parent class can run backward on no grads without throwing error
|
||||
loss = loss.detach()
|
||||
loss.requires_grad_(True)
|
||||
|
||||
return loss
|
||||
|
||||
def get_guided_loss_polarity(
|
||||
noisy_latents: torch.Tensor,
|
||||
conditional_embeds: PromptEmbeds,
|
||||
|
||||
@@ -13,6 +13,7 @@ from toolkit.config_modules import NetworkConfig
|
||||
from toolkit.lorm import extract_conv, extract_linear, count_parameters
|
||||
from toolkit.metadata import add_model_hash_to_meta
|
||||
from toolkit.paths import KEYMAPS_ROOT
|
||||
from toolkit.saving import get_lora_keymap_from_model_keymap
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.lycoris_special import LycorisSpecialNetwork, LoConSpecialModule
|
||||
@@ -338,6 +339,7 @@ class ToolkitNetworkMixin:
|
||||
train_unet: Optional[bool] = True,
|
||||
is_sdxl=False,
|
||||
is_v2=False,
|
||||
is_ssd=False,
|
||||
network_config: Optional[NetworkConfig] = None,
|
||||
is_lorm=False,
|
||||
**kwargs
|
||||
@@ -348,6 +350,7 @@ class ToolkitNetworkMixin:
|
||||
self._multiplier: float = 1.0
|
||||
self.is_active: bool = False
|
||||
self.is_sdxl = is_sdxl
|
||||
self.is_ssd = is_ssd
|
||||
self.is_v2 = is_v2
|
||||
self.is_merged_in = False
|
||||
self.is_lorm = is_lorm
|
||||
@@ -357,14 +360,25 @@ class ToolkitNetworkMixin:
|
||||
self.can_merge_in = not is_lorm
|
||||
|
||||
def get_keymap(self: Network):
|
||||
if self.is_sdxl:
|
||||
use_weight_mapping = False
|
||||
|
||||
if self.is_ssd:
|
||||
keymap_tail = 'ssd'
|
||||
use_weight_mapping = True
|
||||
elif self.is_sdxl:
|
||||
keymap_tail = 'sdxl'
|
||||
elif self.is_v2:
|
||||
keymap_tail = 'sd2'
|
||||
else:
|
||||
keymap_tail = 'sd1'
|
||||
# todo double check this
|
||||
use_weight_mapping = True
|
||||
|
||||
# load keymap
|
||||
keymap_name = f"stable_diffusion_locon_{keymap_tail}.json"
|
||||
if use_weight_mapping:
|
||||
keymap_name = f"stable_diffusion_{keymap_tail}.json"
|
||||
|
||||
keymap_path = os.path.join(KEYMAPS_ROOT, keymap_name)
|
||||
|
||||
keymap = None
|
||||
@@ -373,6 +387,10 @@ class ToolkitNetworkMixin:
|
||||
with open(keymap_path, 'r') as f:
|
||||
keymap = json.load(f)['ldm_diffusers_keymap']
|
||||
|
||||
if use_weight_mapping and keymap is not None:
|
||||
# get keymap from weights
|
||||
keymap = get_lora_keymap_from_model_keymap(keymap)
|
||||
|
||||
return keymap
|
||||
|
||||
def save_weights(
|
||||
|
||||
@@ -206,6 +206,7 @@ def load_t2i_model(
|
||||
|
||||
IP_ADAPTER_MODULES = ['image_proj', 'ip_adapter']
|
||||
|
||||
|
||||
def save_ip_adapter_from_diffusers(
|
||||
combined_state_dict: 'OrderedDict',
|
||||
output_file: str,
|
||||
@@ -241,3 +242,58 @@ def load_ip_adapter_model(
|
||||
return combined_state_dict
|
||||
else:
|
||||
return torch.load(path_to_file, map_location=device)
|
||||
|
||||
|
||||
def get_lora_keymap_from_model_keymap(model_keymap: 'OrderedDict') -> 'OrderedDict':
|
||||
lora_keymap = OrderedDict()
|
||||
|
||||
# see if we have dual text encoders " a key that starts with conditioner.embedders.1
|
||||
has_dual_text_encoders = False
|
||||
for key in model_keymap:
|
||||
if key.startswith('conditioner.embedders.1'):
|
||||
has_dual_text_encoders = True
|
||||
break
|
||||
|
||||
# map through the keys and values
|
||||
for key, value in model_keymap.items():
|
||||
# ignore bias weights
|
||||
if key.endswith('bias'):
|
||||
continue
|
||||
if key.endswith('.weight'):
|
||||
# remove the .weight
|
||||
key = key[:-7]
|
||||
if value.endswith(".weight"):
|
||||
# remove the .weight
|
||||
value = value[:-7]
|
||||
|
||||
# unet for all
|
||||
key = key.replace('model.diffusion_model', 'lora_unet')
|
||||
if value.startswith('unet'):
|
||||
value = f"lora_{value}"
|
||||
|
||||
# text encoder
|
||||
if has_dual_text_encoders:
|
||||
key = key.replace('conditioner.embedders.0', 'lora_te1')
|
||||
key = key.replace('conditioner.embedders.1', 'lora_te2')
|
||||
if value.startswith('te0') or value.startswith('te1'):
|
||||
value = f"lora_{value}"
|
||||
value.replace('lora_te1', 'lora_te2')
|
||||
value.replace('lora_te0', 'lora_te1')
|
||||
|
||||
key = key.replace('cond_stage_model.transformer', 'lora_te')
|
||||
|
||||
if value.startswith('te_'):
|
||||
value = f"lora_{value}"
|
||||
|
||||
# replace periods with underscores
|
||||
key = key.replace('.', '_')
|
||||
value = value.replace('.', '_')
|
||||
|
||||
# add all the weights
|
||||
lora_keymap[f"{key}.lora_down.weight"] = f"{value}.lora_down.weight"
|
||||
lora_keymap[f"{key}.lora_down.bias"] = f"{value}.lora_down.bias"
|
||||
lora_keymap[f"{key}.lora_up.weight"] = f"{value}.lora_up.weight"
|
||||
lora_keymap[f"{key}.lora_up.bias"] = f"{value}.lora_up.bias"
|
||||
lora_keymap[f"{key}.alpha"] = f"{value}.alpha"
|
||||
|
||||
return lora_keymap
|
||||
|
||||
@@ -35,7 +35,7 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
|
||||
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
|
||||
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
|
||||
StableDiffusionXLImg2ImgPipeline
|
||||
StableDiffusionXLImg2ImgPipeline, LCMScheduler
|
||||
import diffusers
|
||||
from diffusers import \
|
||||
AutoencoderKL, \
|
||||
@@ -279,6 +279,20 @@ class StableDiffusion:
|
||||
self.load_refiner()
|
||||
self.is_loaded = True
|
||||
|
||||
def te_train(self):
|
||||
if isinstance(self.text_encoder, list):
|
||||
for te in self.text_encoder:
|
||||
te.train()
|
||||
else:
|
||||
self.text_encoder.train()
|
||||
|
||||
def te_eval(self):
|
||||
if isinstance(self.text_encoder, list):
|
||||
for te in self.text_encoder:
|
||||
te.eval()
|
||||
else:
|
||||
self.text_encoder.eval()
|
||||
|
||||
def load_refiner(self):
|
||||
# for now, we are just going to rely on the TE from the base model
|
||||
# which is TE2 for SDXL and TE for SD (no refiner currently)
|
||||
@@ -721,6 +735,7 @@ class StableDiffusion:
|
||||
add_time_ids=None,
|
||||
conditional_embeddings: Union[PromptEmbeds, None] = None,
|
||||
unconditional_embeddings: Union[PromptEmbeds, None] = None,
|
||||
is_input_scaled=True,
|
||||
**kwargs,
|
||||
):
|
||||
with torch.no_grad():
|
||||
@@ -764,6 +779,8 @@ class StableDiffusion:
|
||||
|
||||
|
||||
def scale_model_input(model_input, timestep_tensor):
|
||||
if is_input_scaled:
|
||||
return model_input
|
||||
mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
|
||||
timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
|
||||
out_chunks = []
|
||||
@@ -859,7 +876,7 @@ class StableDiffusion:
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
latent_model_input.to(self.device_torch, self.torch_dtype),
|
||||
timestep,
|
||||
encoder_hidden_states=text_embeddings.text_embeds,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
@@ -903,7 +920,7 @@ class StableDiffusion:
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
latent_model_input.to(self.device_torch, self.torch_dtype),
|
||||
timestep,
|
||||
encoder_hidden_states=text_embeddings.text_embeds,
|
||||
**kwargs,
|
||||
@@ -924,6 +941,15 @@ class StableDiffusion:
|
||||
return noise_pred
|
||||
|
||||
def step_scheduler(self, model_input, latent_input, timestep_tensor):
|
||||
# // sometimes they are on the wrong device, no idea why
|
||||
if isinstance(self.noise_scheduler, DDPMScheduler) or isinstance(self.noise_scheduler, LCMScheduler):
|
||||
try:
|
||||
self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
|
||||
self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch)
|
||||
self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
|
||||
latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0)
|
||||
timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
|
||||
@@ -955,10 +981,12 @@ class StableDiffusion:
|
||||
add_time_ids=None,
|
||||
bleed_ratio: float = 0.5,
|
||||
bleed_latents: torch.FloatTensor = None,
|
||||
is_input_scaled=False,
|
||||
**kwargs,
|
||||
):
|
||||
timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps]
|
||||
|
||||
|
||||
for timestep in tqdm(timesteps_to_run, leave=False):
|
||||
timestep = timestep.unsqueeze_(0)
|
||||
noise_pred = self.predict_noise(
|
||||
@@ -967,6 +995,7 @@ class StableDiffusion:
|
||||
timestep,
|
||||
guidance_scale=guidance_scale,
|
||||
add_time_ids=add_time_ids,
|
||||
is_input_scaled=is_input_scaled,
|
||||
**kwargs,
|
||||
)
|
||||
# some schedulers need to run separately, so do that. (euler for example)
|
||||
@@ -977,6 +1006,9 @@ class StableDiffusion:
|
||||
if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]:
|
||||
latents = (latents * (1 - bleed_ratio)) + (bleed_latents * bleed_ratio)
|
||||
|
||||
# only skip first scaling
|
||||
is_input_scaled = False
|
||||
|
||||
# return latents_steps
|
||||
return latents
|
||||
|
||||
|
||||
Reference in New Issue
Block a user