mirror of https://github.com/ostris/ai-toolkit.git
Various bug fixes and improvements

@@ -2,11 +2,12 @@ import copy
 import random
 from collections import OrderedDict
 import os
+from contextlib import nullcontext
 from typing import Optional, Union, List
 from torch.utils.data import ConcatDataset, DataLoader
 from toolkit.data_loader import PairedImageDataset
 from toolkit.prompt_utils import concat_prompt_embeds
-from toolkit.stable_diffusion_model import StableDiffusion
+from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds
 from toolkit.train_tools import get_torch_dtype
 import gc
 from toolkit import train_tools

@@ -80,34 +81,16 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
         imgs, prompts = batch
         dtype = get_torch_dtype(self.train_config.dtype)
         imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype)

         # split batched images in half so left is negative and right is positive
         negative_images, positive_images = torch.chunk(imgs, 2, dim=3)

+        positive_latents = self.sd.encode_images(positive_images)
+        negative_latents = self.sd.encode_images(negative_images)

         height = positive_images.shape[2]
         width = positive_images.shape[3]
         batch_size = positive_images.shape[0]

-        # encode the images
-        positive_latents = self.sd.vae.encode(positive_images).latent_dist.sample()
-        positive_latents = positive_latents * 0.18215
-        negative_latents = self.sd.vae.encode(negative_images).latent_dist.sample()
-        negative_latents = negative_latents * 0.18215
-
-        embedding_list = []
-        negative_embedding_list = []
-        # embed the prompts
-        for prompt in prompts:
-            embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
-            embedding_list.append(embedding)
-            # just empty for now
-            # todo cache this?
-            negative_embed = self.sd.encode_prompt('').to(self.device_torch, dtype=dtype)
-            negative_embedding_list.append(negative_embed)
-
-        conditional_embeds = concat_prompt_embeds(embedding_list)
-        unconditional_embeds = concat_prompt_embeds(negative_embedding_list)

         if self.train_config.gradient_checkpointing:
             # may get disabled elsewhere
             self.sd.unet.enable_gradient_checkpointing()
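
Note: the hunk above replaces the inline VAE calls with self.sd.encode_images, folding the latent sampling and the SD1.x scaling factor into one helper (its definition appears in the StableDiffusion hunks further down). A minimal sketch of the pattern being centralized, assuming a diffusers-style AutoencoderKL interface as already used in the deleted lines:

    import torch

    SD_LATENT_SCALE = 0.18215  # scaling factor used by SD1.x-family VAEs

    def encode_with_vae(vae, images: torch.Tensor) -> torch.Tensor:
        # images: [batch, 3, H, W], already normalized for the VAE
        latent_dist = vae.encode(images).latent_dist
        return latent_dist.sample() * SD_LATENT_SCALE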

@@ -115,26 +98,12 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
         noise_scheduler = self.sd.noise_scheduler
         optimizer = self.optimizer
         lr_scheduler = self.lr_scheduler
-        loss_function = torch.nn.MSELoss()
-
-        def get_noise_pred(neg, pos, gs, cts, dn):
-            return self.sd.predict_noise(
-                latents=dn,
-                text_embeddings=train_tools.concat_prompt_embeddings(
-                    neg,  # negative prompt
-                    pos,  # positive prompt
-                    self.train_config.batch_size,
-                ),
-                timestep=cts,
-                guidance_scale=gs,
-            )
-
-        with torch.no_grad():
         self.sd.noise_scheduler.set_timesteps(
             self.train_config.max_denoising_steps, device=self.device_torch
         )

-        timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch)
+        timesteps = torch.randint(0, self.train_config.max_denoising_steps, (1,), device=self.device_torch)
         timesteps = timesteps.long()

         # get noise
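
Note: the timestep shape change from (batch_size,) to (1,) means a single timestep is drawn per training step and shared, which keeps the positive and negative halves of each image pair at the same noise level once the tensors are doubled later in the step. A quick illustration:

    import torch

    max_denoising_steps = 1000
    timesteps = torch.randint(0, max_denoising_steps, (1,)).long()
    timesteps = torch.cat([timesteps, timesteps], dim=0)  # doubled to match the concatenated latents
    print(timesteps)  # e.g. tensor([421, 421])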

@@ -147,6 +116,7 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):

         if do_mirror_loss:
             # mirror the noise
+            # torch shape is [batch, channels, height, width]
             noise_negative = torch.flip(noise_positive.clone(), dims=[3])
         else:
             noise_negative = noise_positive.clone()
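
Note: torch.flip(..., dims=[3]) mirrors along the width axis, as the new comment documents, so the negative half sees a left/right reflection of the positive half's noise. A self-contained check:

    import torch

    noise_positive = torch.randn(1, 4, 64, 64)  # [batch, channels, height, width]
    noise_negative = torch.flip(noise_positive.clone(), dims=[3])
    # the first column of the mirrored noise equals the last column of the original
    assert torch.equal(noise_negative[..., 0], noise_positive[..., -1])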

@@ -159,8 +129,6 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
         noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0)
         noise = torch.cat([noise_positive, noise_negative], dim=0)
         timesteps = torch.cat([timesteps, timesteps], dim=0)
-        conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
-        unconditional_embeds = concat_prompt_embeds([unconditional_embeds, unconditional_embeds])
         network_multiplier = [1.0, -1.0]

         flush()

@@ -170,22 +138,31 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
         loss_mirror_float = None

         self.optimizer.zero_grad()
+        noisy_latents.requires_grad = False

+        # if training text encoder enable grads, else do context of no grad
+        with torch.set_grad_enabled(self.train_config.train_text_encoder):
+            # text encoding
+            embedding_list = []
+            # embed the prompts
+            for prompt in prompts:
+                embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
+                embedding_list.append(embedding)
+            conditional_embeds = concat_prompt_embeds(embedding_list)
+            conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])

         with self.network:
             assert self.network.is_active
-            loss_list = []
-
-            # do positive first
             self.network.multiplier = network_multiplier

-            noise_pred = get_noise_pred(
-                unconditional_embeds,
-                conditional_embeds,
-                1,
-                timesteps,
-                noisy_latents
-            )
+            noise_pred = self.sd.predict_noise(
+                latents=noisy_latents,
+                conditional_embeddings=conditional_embeds,
+                timestep=timesteps,
             )

-            if self.sd.is_v2:  # check is vpred, don't want to track it down right now
+            if self.sd.prediction_type == 'v_prediction':
                 # v-parameterization training
                 target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
             else:
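
Note: prompt encoding now happens inside the training step, wrapped in torch.set_grad_enabled(self.train_config.train_text_encoder) so gradients flow through the text encoder only when it is actually being trained. A minimal demonstration of the context manager:

    import torch

    train_text_encoder = False  # stand-in for self.train_config.train_text_encoder
    x = torch.ones(2, requires_grad=True)
    with torch.set_grad_enabled(train_text_encoder):
        y = x * 2
    print(y.requires_grad)  # False: the "encoder" ran without building a graph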

@@ -199,7 +176,6 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
             loss = loss.mean()
             loss_slide_float = loss.item()

-
             if do_mirror_loss:
                 noise_pred_pos, noise_pred_neg = torch.chunk(noise_pred, 2, dim=0)
                 # mirror the negative

@@ -221,7 +197,6 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
         optimizer.step()
         lr_scheduler.step()

-
         # reset network
         self.network.multiplier = 1.0


@@ -9,17 +9,21 @@ config:
   # for tensorboard logging
   log_dir: "/home/jaret/Dev/.tensorboard"
   network:
-    type: "lierla" # lierla is traditional LoRA that works everywhere, only linear layers
-    rank: 16
-    alpha: 8
+    type: "lora"
+    linear: 64
+    linear_alpha: 32
+    conv: 32
+    conv_alpha: 16
   train:
     noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
-    steps: 1000
-    lr: 5e-5
+    steps: 5000
+    lr: 1e-4
     train_unet: true
     gradient_checkpointing: true
-    train_text_encoder: false
-    optimizer: "lion8bit"
+    train_text_encoder: true
+    optimizer: "adamw"
+    optimizer_params:
+      weight_decay: 1e-2
     lr_scheduler: "constant"
     max_denoising_steps: 1000
     batch_size: 1
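
Note: the network block switches from the lierla-style rank/alpha keys to per-layer-type sizing, with linear/linear_alpha for linear layers and conv/conv_alpha for convolutions. Assuming the usual LoRA convention where the effective weight scale is alpha divided by rank, both the old and the new values preserve a 0.5 scale:

    def lora_scale(rank: int, alpha: float) -> float:
        # standard LoRA scaling; assumes the common alpha / rank convention
        return alpha / rank

    print(lora_scale(16, 8))   # 0.5 -> old rank/alpha pair
    print(lora_scale(64, 32))  # 0.5 -> new linear layers
    print(lora_scale(32, 16))  # 0.5 -> new conv layers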

@@ -36,11 +40,11 @@ config:
     is_v_pred: false # for v-prediction models (most v2 models)
   save:
     dtype: float16 # precision to save
-    save_every: 100 # save every this many steps
+    save_every: 1000 # save every this many steps
     max_step_saves_to_keep: 2 # only affects step counts
   sample:
     sampler: "ddpm" # must match train.noise_scheduler
-    sample_every: 20 # sample every this many steps
+    sample_every: 100 # sample every this many steps
     width: 512
     height: 512
     prompts:

@@ -81,6 +85,8 @@ config:
       - 512
     slider_pair_folder: "/mnt/Datasets/stable-diffusion/slider_reference/subject_turner"
     target_class: "photo of a person"
+    # additional_losses:
+    #   - "mirror"


   meta:

@@ -97,21 +97,25 @@ class LoRAModule(torch.nn.Module):
             if len(self.multiplier) == 0:
                 # single item, just return it
                 return self.multiplier[0]
+            elif len(self.multiplier) == batch_size:
+                # not doing CFG
+                multiplier_tensor = torch.tensor(self.multiplier).to(lora_up.device, dtype=lora_up.dtype)
             else:
                 # we have a list of multipliers, so we need to get the multiplier for this batch
                 multiplier_tensor = torch.tensor(self.multiplier * 2).to(lora_up.device, dtype=lora_up.dtype)
                 # should be 1 for if total batch size was 1
                 num_interleaves = (batch_size // 2) // len(self.multiplier)
                 multiplier_tensor = multiplier_tensor.repeat_interleave(num_interleaves)

             # match lora_up rank
             if len(lora_up.size()) == 2:
                 multiplier_tensor = multiplier_tensor.view(-1, 1)
             elif len(lora_up.size()) == 3:
                 multiplier_tensor = multiplier_tensor.view(-1, 1, 1)
             elif len(lora_up.size()) == 4:
                 multiplier_tensor = multiplier_tensor.view(-1, 1, 1, 1)
             return multiplier_tensor

         else:
             return self.multiplier
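
Note: the new elif branch covers the non-CFG case, where the multiplier list already lines up one-to-one with the batch; the else branch still handles CFG-doubled batches by tiling the list and interleaving. The doubled-batch path, reduced to a runnable sketch:

    import torch

    multiplier = [1.0, -1.0]  # positive / negative network direction
    batch_size = 4            # CFG-doubled: 2 positive + 2 negative rows
    multiplier_tensor = torch.tensor(multiplier * 2)
    num_interleaves = (batch_size // 2) // len(multiplier)
    multiplier_tensor = multiplier_tensor.repeat_interleave(num_interleaves)
    print(multiplier_tensor.view(-1, 1))  # one scale per batch row (rank-2 weight case)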

@@ -7,9 +7,11 @@ import os
 from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
 from safetensors.torch import save_file
 from tqdm import tqdm
+from torchvision.transforms import Resize

 from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
     convert_vae_state_dict
+from toolkit import train_tools
 from toolkit.config_modules import ModelConfig, GenerateImageConfig
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.paths import REPOS_ROOT

@@ -180,6 +182,7 @@ class StableDiffusion:
                 device=self.device_torch,
                 load_safety_checker=False,
                 requires_safety_checker=False,
+                safety_checker=False
             ).to(self.device_torch)
         else:
             pipe = pipln.from_single_file(

@@ -189,7 +192,9 @@ class StableDiffusion:
                 device=self.device_torch,
                 load_safety_checker=False,
                 requires_safety_checker=False,
+                safety_checker=False
             ).to(self.device_torch)
+
         pipe.register_to_config(requires_safety_checker=False)
         text_encoder = pipe.text_encoder
         text_encoder.to(self.device_torch, dtype=dtype)

@@ -379,28 +384,60 @@ class StableDiffusion:
                 dynamic_crops=False,  # look into this
                 dtype=dtype,
             ).to(self.device_torch, dtype=dtype)
-            return train_util.concat_embeddings(
-                prompt_ids, prompt_ids, bs
-            )
+            return prompt_ids
         else:
             return None

     def predict_noise(
             self,
-            latents: torch.FloatTensor,
-            text_embeddings: PromptEmbeds,
-            timestep: int,
+            latents: torch.Tensor,
+            text_embeddings: Union[PromptEmbeds, None] = None,
+            timestep: Union[int, torch.Tensor] = 1,
             guidance_scale=7.5,
-            guidance_rescale=0,  # 0.7
+            guidance_rescale=0,  # 0.7 sdxl
             add_time_ids=None,
+            conditional_embeddings: Union[PromptEmbeds, None] = None,
+            unconditional_embeddings: Union[PromptEmbeds, None] = None,
             **kwargs,
     ):
+        # get the embeddings
+        if text_embeddings is None and conditional_embeddings is None:
+            raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
+        if text_embeddings is None and unconditional_embeddings is not None:
+            text_embeddings = train_tools.concat_prompt_embeddings(
+                unconditional_embeddings,  # negative embedding
+                conditional_embeddings,  # positive embedding
+                latents.shape[0],  # batch size
+            )
+        elif text_embeddings is None and conditional_embeddings is not None:
+            # not doing cfg
+            text_embeddings = conditional_embeddings
+
+        # CFG is comparing neg and positive, if we have concatenated embeddings
+        # then we are doing it, otherwise we are not and takes half the time.
+        do_classifier_free_guidance = True
+
+        # check if batch size of embeddings matches batch size of latents
+        if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
+            do_classifier_free_guidance = False
+        elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
+            raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")

         if self.is_xl:
             if add_time_ids is None:
                 add_time_ids = self.get_time_ids_from_latents(latents)

-            latent_model_input = torch.cat([latents] * 2)
+            if do_classifier_free_guidance:
+                # todo check this with larget batches
+                train_util.concat_embeddings(
+                    add_time_ids, add_time_ids, 1
+                )
+            else:
+                # concat to fit batch size
+                add_time_ids = torch.cat([add_time_ids] * latents.shape[0])
+
+            if do_classifier_free_guidance:
+                latent_model_input = torch.cat([latents] * 2)

             latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
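
Note: predict_noise now infers whether classifier-free guidance is in play purely from tensor shapes: embeddings whose batch matches the latents mean no CFG, a doubled batch means CFG, and anything else is an error. Reduced to a standalone helper (the function name here is illustrative, not from the commit):

    import torch

    def infer_cfg(latents: torch.Tensor, text_embeds: torch.Tensor) -> bool:
        if latents.shape[0] == text_embeds.shape[0]:
            return False  # conditional-only: runs in half the time
        if latents.shape[0] * 2 == text_embeds.shape[0]:
            return True   # [uncond | cond] stacked on the batch dim
        raise ValueError("embeddings must match the latents batch size or double it")

    latents = torch.zeros(2, 4, 64, 64)
    print(infer_cfg(latents, torch.zeros(4, 77, 768)))  # True
    print(infer_cfg(latents, torch.zeros(2, 77, 768)))  # False

(As committed, the XL time-ids branch calls train_util.concat_embeddings without assigning the result, and latent_model_input is only set in the CFG case; presumably follow-up fixes were intended.)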

@@ -417,20 +454,24 @@ class StableDiffusion:
                 added_cond_kwargs=added_cond_kwargs,
             ).sample

-            # perform guidance
-            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-            noise_pred = noise_pred_uncond + guidance_scale * (
-                noise_pred_text - noise_pred_uncond
-            )
+            if do_classifier_free_guidance:
+                # perform guidance
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (
+                    noise_pred_text - noise_pred_uncond
+                )

                 # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
                 if guidance_rescale > 0.0:
                     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                     noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

         else:
-            # if we are doing classifier free guidance, need to double up
-            latent_model_input = torch.cat([latents] * 2)
+            if do_classifier_free_guidance:
+                # if we are doing classifier free guidance, need to double up
+                latent_model_input = torch.cat([latents] * 2)
+            else:
+                latent_model_input = latents

             latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)

@@ -441,10 +482,12 @@ class StableDiffusion:
                 encoder_hidden_states=text_embeddings.text_embeds,
             ).sample

-            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-            noise_pred = noise_pred_uncond + guidance_scale * (
-                noise_pred_text - noise_pred_uncond
-            )
+            if do_classifier_free_guidance:
+                # perform guidance
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (
+                    noise_pred_text - noise_pred_uncond
+                )

         return noise_pred
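
Note: both the XL and non-XL paths now gate the uncond/cond split behind do_classifier_free_guidance, since without CFG there is nothing to chunk. The guidance combine itself is the usual formula:

    import torch

    guidance_scale = 7.5
    noise_pred = torch.randn(4, 4, 64, 64)  # [uncond | cond] stacked on dim 0
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    print(guided.shape)  # torch.Size([2, 4, 64, 64])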

@@ -495,14 +538,68 @@ class StableDiffusion:
             )
         )

+    def encode_images(
+            self,
+            image_list: List[torch.Tensor],
+            device=None,
+            dtype=None
+    ):
+        if device is None:
+            device = self.device
+        if dtype is None:
+            dtype = self.torch_dtype
+
+        latent_list = []
+        # Move to vae to device if on cpu
+        if self.vae.device == 'cpu':
+            self.vae.to(self.device)
+        # move to device and dtype
+        image_list = [image.to(self.device, dtype=self.torch_dtype) for image in image_list]
+
+        # resize images if not divisible by 8
+        for i in range(len(image_list)):
+            image = image_list[i]
+            if image.shape[1] % 8 != 0 or image.shape[2] % 8 != 0:
+                image_list[i] = Resize((image.shape[1] // 8 * 8, image.shape[2] // 8 * 8))(image)
+
+        images = torch.stack(image_list)
+        latents = self.vae.encode(images).latent_dist.sample()
+        latents = latents * 0.18215
+        latents = latents.to(device, dtype=dtype)
+
+        return latents
+
+    def encode_image_prompt_pairs(
+            self,
+            prompt_list: List[str],
+            image_list: List[torch.Tensor],
+            device=None,
+            dtype=None
+    ):
+        # todo check image types and expand and rescale as needed
+        # device and dtype are for outputs
+        if device is None:
+            device = self.device
+        if dtype is None:
+            dtype = self.torch_dtype
+
+        embedding_list = []
+        latent_list = []
+        # embed the prompts
+        for prompt in prompt_list:
+            embedding = self.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
+            embedding_list.append(embedding)
+
+        return embedding_list, latent_list
+
     def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
         state_dict = {}

         def update_sd(prefix, sd):
             for k, v in sd.items():
                 key = prefix + k
-                v = v.detach().clone().to("cpu").to(get_torch_dtype(save_dtype))
-                state_dict[key] = v
+                v = v.detach().clone()
+                state_dict[key] = v.to("cpu", dtype=get_torch_dtype(save_dtype))

         # todo see what logit scale is
         if self.is_xl:
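
Note: the new encode_images snaps inputs down to the nearest multiple of 8 before encoding, since the SD VAE downsamples by a factor of 8 and off-size images would otherwise produce mismatched latent shapes. The resize step in isolation:

    import torch
    from torchvision.transforms import Resize

    image = torch.rand(3, 517, 781)  # [C, H, W] with awkward dimensions
    h, w = image.shape[1] // 8 * 8, image.shape[2] // 8 * 8
    image = Resize((h, w))(image)
    print(image.shape)  # torch.Size([3, 512, 776])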

@@ -536,4 +633,6 @@ class StableDiffusion:

         # prepare metadata
         meta = get_meta_for_safetensors(meta)
+        # make sure parent folder exists
+        os.makedirs(os.path.dirname(output_file), exist_ok=True)
         save_file(state_dict, output_file, metadata=meta)
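
Note: update_sd now detaches and clones first, then does the CPU move and dtype cast in a single .to call when storing, and save() creates the parent directory before writing. A sketch of the save-side pattern, assuming float16 output:

    import torch

    state_dict = {}

    def update_sd(prefix, sd, save_dtype=torch.float16):
        for k, v in sd.items():
            # detach from the graph, then move + downcast in one step
            state_dict[prefix + k] = v.detach().clone().to("cpu", dtype=save_dtype)

    update_sd("unet.", {"w": torch.randn(2, 2)})
    print(state_dict["unet.w"].dtype)  # torch.float16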

@@ -34,13 +34,16 @@ SCHEDLER_SCHEDULE = "scaled_linear"


 def get_torch_dtype(dtype_str):
+    # if it is a torch dtype, return it
+    if isinstance(dtype_str, torch.dtype):
+        return dtype_str
     if dtype_str == "float" or dtype_str == "fp32" or dtype_str == "single" or dtype_str == "float32":
         return torch.float
     if dtype_str == "fp16" or dtype_str == "half" or dtype_str == "float16":
         return torch.float16
     if dtype_str == "bf16" or dtype_str == "bfloat16":
         return torch.bfloat16
-    return None
+    return dtype_str


 def replace_filewords_prompt(prompt, args: argparse.Namespace):
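
Note: get_torch_dtype becomes idempotent: a value that is already a torch.dtype (or any unrecognized value) is passed through instead of collapsing to None, so callers can feed it either strings or dtypes. For example, assuming the ai-toolkit package is importable:

    import torch
    from toolkit.train_tools import get_torch_dtype

    assert get_torch_dtype("bf16") is torch.bfloat16
    assert get_torch_dtype(torch.float16) is torch.float16  # previously returned None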