mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-24 06:13:56 +00:00
Moved SD batch processing to a shared method and added it for use in slider training. Still testing whether it affects quality compared to sampling.
This commit is contained in:
@@ -36,67 +36,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.sd.text_encoder.train()
|
||||
|
||||
def hook_train_loop(self, batch):
|
||||
with torch.no_grad():
|
||||
imgs, prompts, dataset_config = batch
|
||||
|
||||
# convert the 0 or 1 for is reg to a bool list
|
||||
is_reg_list = dataset_config.get('is_reg', [0 for _ in range(imgs.shape[0])])
|
||||
if isinstance(is_reg_list, torch.Tensor):
|
||||
is_reg_list = is_reg_list.numpy().tolist()
|
||||
is_reg_list = [bool(x) for x in is_reg_list]
|
||||
|
||||
conditioned_prompts = []
|
||||
|
||||
for prompt, is_reg in zip(prompts, is_reg_list):
|
||||
|
||||
# make sure the embedding is in the prompts
|
||||
if self.embedding is not None:
|
||||
prompt = self.embedding.inject_embedding_to_prompt(
|
||||
prompt,
|
||||
expand_token=True,
|
||||
add_if_not_present=True,
|
||||
)
|
||||
|
||||
# make sure trigger is in the prompts if not a regularization run
|
||||
if self.trigger_word is not None and not is_reg:
|
||||
prompt = self.sd.inject_trigger_into_prompt(
|
||||
prompt,
|
||||
add_if_not_present=True,
|
||||
)
|
||||
conditioned_prompts.append(prompt)
|
||||
|
||||
batch_size = imgs.shape[0]
|
||||
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
imgs = imgs.to(self.device_torch, dtype=dtype)
|
||||
latents = self.sd.encode_images(imgs)
|
||||
|
||||
noise_scheduler = self.sd.noise_scheduler
|
||||
optimizer = self.optimizer
|
||||
lr_scheduler = self.lr_scheduler
|
||||
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
self.train_config.max_denoising_steps, device=self.device_torch
|
||||
)
|
||||
|
||||
timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# get noise
|
||||
noise = self.sd.get_latent_noise(
|
||||
pixel_height=imgs.shape[2],
|
||||
pixel_width=imgs.shape[3],
|
||||
batch_size=batch_size,
|
||||
noise_offset=self.train_config.noise_offset
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# remove grads for these
|
||||
noisy_latents.requires_grad = False
|
||||
noise.requires_grad = False
|
||||
|
||||
flush()
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
@@ -135,7 +76,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
if self.sd.prediction_type == 'v_prediction':
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
|
||||
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
|
||||
@@ -144,7 +85,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma)
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
|
||||
loss = loss.mean()
|
||||
|
||||
@@ -153,9 +94,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
flush()
|
||||
|
||||
# apply gradients
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
lr_scheduler.step()
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
self.lr_scheduler.step()
|
||||
|
||||
if self.embedding is not None:
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
|
||||
@@ -283,6 +283,66 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
else:
|
||||
print("load_weights not implemented for non-network models")
|
||||
|
||||
def process_general_training_batch(self, batch):
    """Prepare one dataloader batch for a noise-prediction training step.

    The batch is unpacked into images, prompts and a dataset config dict.
    Prompts get the embedding token and/or trigger word injected, images are
    encoded with ``self.sd.encode_images`` and noised at random timesteps.
    Everything runs under ``torch.no_grad()``.

    Args:
        batch: A ``(imgs, prompts, dataset_config)`` triple. ``imgs`` is
            indexed as ``shape[0]`` = batch size, ``shape[2]`` = pixel height,
            ``shape[3]`` = pixel width. ``dataset_config`` may carry an
            ``'is_reg'`` entry (list or tensor of 0/1 flags, one per sample).

    Returns:
        A tuple ``(noisy_latents, noise, timesteps, conditioned_prompts, imgs)``
        where ``imgs`` has been moved to the training device and dtype.
    """
    with torch.no_grad():
        imgs, prompts, dataset_config = batch
        batch_size = imgs.shape[0]

        # convert the 0 or 1 for is reg to a bool list (defaults to "not reg")
        is_reg_list = dataset_config.get('is_reg', [0] * batch_size)
        if isinstance(is_reg_list, torch.Tensor):
            # Tensor.tolist() works regardless of device / requires_grad,
            # whereas going through .numpy() raises for CUDA tensors.
            is_reg_list = is_reg_list.tolist()
        is_reg_list = [bool(x) for x in is_reg_list]

        conditioned_prompts = []
        for prompt, is_reg in zip(prompts, is_reg_list):
            # make sure the embedding is in the prompts
            if self.embedding is not None:
                prompt = self.embedding.inject_embedding_to_prompt(
                    prompt,
                    expand_token=True,
                    add_if_not_present=True,
                )

            # make sure trigger is in the prompts if not a regularization run
            if self.trigger_word is not None and not is_reg:
                prompt = self.sd.inject_trigger_into_prompt(
                    prompt,
                    add_if_not_present=True,
                )
            conditioned_prompts.append(prompt)

        dtype = get_torch_dtype(self.train_config.dtype)
        imgs = imgs.to(self.device_torch, dtype=dtype)
        latents = self.sd.encode_images(imgs)

        self.sd.noise_scheduler.set_timesteps(
            self.train_config.max_denoising_steps, device=self.device_torch
        )

        # one random step in [0, max_denoising_steps) per sample
        timesteps = torch.randint(
            0,
            self.train_config.max_denoising_steps,
            (batch_size,),
            device=self.device_torch,
        ).long()

        # get noise
        noise = self.sd.get_latent_noise(
            pixel_height=imgs.shape[2],
            pixel_width=imgs.shape[3],
            batch_size=batch_size,
            noise_offset=self.train_config.noise_offset
        ).to(self.device_torch, dtype=dtype)

        noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)

        # remove grads for these
        noisy_latents.requires_grad = False
        noise.requires_grad = False

    return noisy_latents, noise, timesteps, conditioned_prompts, imgs
|
||||
|
||||
def run(self):
|
||||
# run base process run
|
||||
BaseTrainProcess.run(self)
|
||||
|
||||
@@ -169,11 +169,16 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
self.prompt_pairs = prompt_pairs
|
||||
# self.anchor_pairs = anchor_pairs
|
||||
flush()
|
||||
if self.data_loader is not None:
|
||||
# we will have images, prep the vae
|
||||
self.sd.vae.eval()
|
||||
self.sd.vae.to(self.device_torch)
|
||||
# end hook_before_train_loop
|
||||
|
||||
def hook_train_loop(self, batch):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
|
||||
|
||||
# get a random pair
|
||||
prompt_pair: EncodedPromptPair = self.prompt_pairs[
|
||||
torch.randint(0, len(self.prompt_pairs), (1,)).item()
|
||||
@@ -207,55 +212,67 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
self.train_config.max_denoising_steps, device=self.device_torch
|
||||
)
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# get a random number of steps
|
||||
timesteps_to = torch.randint(
|
||||
1, self.train_config.max_denoising_steps, (1,)
|
||||
).item()
|
||||
|
||||
# for a complete slider, the batch size is 4 to begin with now
|
||||
true_batch_size = prompt_pair.target_class.text_embeds.shape[0] * self.train_config.batch_size
|
||||
from_batch = False
|
||||
if batch is not None:
|
||||
# training from a batch of images, not generating ourselves
|
||||
from_batch = True
|
||||
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
||||
|
||||
# get noise
|
||||
noise = self.sd.get_latent_noise(
|
||||
pixel_height=height,
|
||||
pixel_width=width,
|
||||
batch_size=true_batch_size,
|
||||
noise_offset=self.train_config.noise_offset,
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
denoised_latent_chunks = [noisy_latents] * self.prompt_chunk_size
|
||||
denoised_latents = torch.cat(denoised_latent_chunks, dim=0)
|
||||
current_timestep = timesteps
|
||||
else:
|
||||
|
||||
# get latents
|
||||
latents = noise * self.sd.noise_scheduler.init_noise_sigma
|
||||
latents = latents.to(self.device_torch, dtype=dtype)
|
||||
|
||||
with self.network:
|
||||
assert self.network.is_active
|
||||
# pass the multiplier list to the network
|
||||
self.network.multiplier = prompt_pair.multiplier_list
|
||||
denoised_latents = self.sd.diffuse_some_steps(
|
||||
latents, # pass simple noise latents
|
||||
train_tools.concat_prompt_embeddings(
|
||||
prompt_pair.positive_target, # unconditional
|
||||
prompt_pair.target_class, # target
|
||||
self.train_config.batch_size,
|
||||
),
|
||||
start_timesteps=0,
|
||||
total_timesteps=timesteps_to,
|
||||
guidance_scale=3,
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
self.train_config.max_denoising_steps, device=self.device_torch
|
||||
)
|
||||
|
||||
# split the latents into out prompt pair chunks
|
||||
denoised_latent_chunks = torch.chunk(denoised_latents, self.prompt_chunk_size, dim=0)
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
noise_scheduler.set_timesteps(1000)
|
||||
# get a random number of steps
|
||||
timesteps_to = torch.randint(
|
||||
1, self.train_config.max_denoising_steps, (1,)
|
||||
).item()
|
||||
|
||||
current_timestep_index = int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
|
||||
current_timestep = noise_scheduler.timesteps[current_timestep_index]
|
||||
|
||||
# get noise
|
||||
noise = self.sd.get_latent_noise(
|
||||
pixel_height=height,
|
||||
pixel_width=width,
|
||||
batch_size=true_batch_size,
|
||||
noise_offset=self.train_config.noise_offset,
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
|
||||
# get latents
|
||||
latents = noise * self.sd.noise_scheduler.init_noise_sigma
|
||||
latents = latents.to(self.device_torch, dtype=dtype)
|
||||
|
||||
with self.network:
|
||||
assert self.network.is_active
|
||||
# pass the multiplier list to the network
|
||||
self.network.multiplier = prompt_pair.multiplier_list
|
||||
denoised_latents = self.sd.diffuse_some_steps(
|
||||
latents, # pass simple noise latents
|
||||
train_tools.concat_prompt_embeddings(
|
||||
prompt_pair.positive_target, # unconditional
|
||||
prompt_pair.target_class, # target
|
||||
self.train_config.batch_size,
|
||||
),
|
||||
start_timesteps=0,
|
||||
total_timesteps=timesteps_to,
|
||||
guidance_scale=3,
|
||||
)
|
||||
|
||||
noise_scheduler.set_timesteps(1000)
|
||||
|
||||
# split the latents into out prompt pair chunks
|
||||
denoised_latent_chunks = torch.chunk(denoised_latents, self.prompt_chunk_size, dim=0)
|
||||
denoised_latent_chunks = [x.detach() for x in denoised_latent_chunks]
|
||||
|
||||
current_timestep_index = int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
|
||||
current_timestep = noise_scheduler.timesteps[current_timestep_index]
|
||||
|
||||
# flush() # 4.2GB to 3GB on 512x512
|
||||
|
||||
@@ -267,6 +284,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
current_timestep,
|
||||
denoised_latents
|
||||
)
|
||||
positive_latents = positive_latents.detach()
|
||||
positive_latents.requires_grad = False
|
||||
positive_latents_chunks = torch.chunk(positive_latents, self.prompt_chunk_size, dim=0)
|
||||
|
||||
@@ -277,6 +295,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
current_timestep,
|
||||
denoised_latents
|
||||
)
|
||||
neutral_latents = neutral_latents.detach()
|
||||
neutral_latents.requires_grad = False
|
||||
neutral_latents_chunks = torch.chunk(neutral_latents, self.prompt_chunk_size, dim=0)
|
||||
|
||||
@@ -287,9 +306,12 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
current_timestep,
|
||||
denoised_latents
|
||||
)
|
||||
unconditional_latents = unconditional_latents.detach()
|
||||
unconditional_latents.requires_grad = False
|
||||
unconditional_latents_chunks = torch.chunk(unconditional_latents, self.prompt_chunk_size, dim=0)
|
||||
|
||||
denoised_latents = denoised_latents.detach()
|
||||
|
||||
flush() # 4.2GB to 3GB on 512x512
|
||||
|
||||
# 4.20 GB RAM for 512x512
|
||||
@@ -402,10 +424,14 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
# match batch size
|
||||
timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])]
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler, self.train_config.min_snr_gamma)
|
||||
if from_batch:
|
||||
# match batch size
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
else:
|
||||
# match batch size
|
||||
timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])]
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler, self.train_config.min_snr_gamma)
|
||||
|
||||
loss = loss.mean() * prompt_pair_chunk.weight
|
||||
|
||||
@@ -427,7 +453,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
positive_latents,
|
||||
neutral_latents,
|
||||
unconditional_latents,
|
||||
latents
|
||||
# latents
|
||||
)
|
||||
# move back to cpu
|
||||
prompt_pair.to("cpu")
|
||||
|
||||
@@ -77,6 +77,7 @@ class TrainConfig:
|
||||
self.skip_first_sample = kwargs.get('skip_first_sample', False)
|
||||
self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
|
||||
self.weight_jitter = kwargs.get('weight_jitter', 0.0)
|
||||
self.merge_network_on_save = kwargs.get('merge_network_on_save', False)
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
|
||||
@@ -101,14 +101,16 @@ if __name__ == '__main__':
|
||||
from PIL import Image
|
||||
import torchvision.transforms as transforms
|
||||
user_path = os.path.expanduser('~')
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
input_path = os.path.join(user_path, "Pictures/sample_2_512.png")
|
||||
output_path = os.path.join(user_path, "Pictures/sample_2_512_llvae.png")
|
||||
img = Image.open(input_path)
|
||||
img_tensor = transforms.ToTensor()(img)
|
||||
img_tensor = img_tensor.unsqueeze(0)
|
||||
img_tensor = img_tensor.unsqueeze(0).to(device=device, dtype=dtype)
|
||||
print("input_shape: ", list(img_tensor.shape))
|
||||
vae = LosslessLatentVAE(in_channels=3, latent_depth=8)
|
||||
vae = LosslessLatentVAE(in_channels=3, latent_depth=8, dtype=dtype).to(device=device, dtype=dtype)
|
||||
latent = vae.encode(img_tensor)
|
||||
print("latent_shape: ", list(latent.shape))
|
||||
out_tensor = vae.decode(latent)
|
||||
|
||||
@@ -152,6 +152,11 @@ class StableDiffusion:
|
||||
steps_offset=1
|
||||
)
|
||||
|
||||
# move the betas, alphas and alphas_cumprod to device. Sometimes they get stuck on cpu, not sure why
|
||||
scheduler.betas = scheduler.betas.to(self.device_torch)
|
||||
scheduler.alphas = scheduler.alphas.to(self.device_torch)
|
||||
scheduler.alphas_cumprod = scheduler.alphas_cumprod.to(self.device_torch)
|
||||
|
||||
model_path = self.model_config.name_or_path
|
||||
if 'civitai.com' in self.model_config.name_or_path:
|
||||
# load is a civit ai model, use the loader.
|
||||
@@ -323,6 +328,13 @@ class StableDiffusion:
|
||||
|
||||
# todo do we disable text encoder here as well if disabled for model, or only do that for training?
|
||||
if self.is_xl:
|
||||
# fix guidance rescale for sdxl
|
||||
# was trained on 0.7 (I believe)
|
||||
|
||||
grs = gen_config.guidance_rescale
|
||||
if grs is None or grs < 0.00001:
|
||||
grs = 0.7
|
||||
|
||||
img = pipeline(
|
||||
prompt=gen_config.prompt,
|
||||
prompt_2=gen_config.prompt_2,
|
||||
@@ -332,7 +344,7 @@ class StableDiffusion:
|
||||
width=gen_config.width,
|
||||
num_inference_steps=gen_config.num_inference_steps,
|
||||
guidance_scale=gen_config.guidance_scale,
|
||||
guidance_rescale=gen_config.guidance_rescale,
|
||||
guidance_rescale=grs,
|
||||
).images[0]
|
||||
else:
|
||||
img = pipeline(
|
||||
@@ -619,6 +631,27 @@ class StableDiffusion:
|
||||
|
||||
return latents
|
||||
|
||||
def decode_latents(
        self,
        latents: torch.Tensor,
        device=None,
        dtype=None
):
    """Decode latents into images with the model's VAE.

    Args:
        latents: Latent tensor to decode.
        device: Device for the latents and the returned images. Defaults to
            ``self.device``.
        dtype: Dtype for the latents and the returned images. Defaults to
            ``self.torch_dtype``.

    Returns:
        The ``.sample`` of ``self.vae.decode(...)``, moved to ``device`` / ``dtype``.
    """
    if device is None:
        device = self.device
    if dtype is None:
        dtype = self.torch_dtype

    # Move the vae to the model device if it is still on cpu.
    # A torch.device never compares equal to the string 'cpu', so compare
    # on the device type instead.
    vae_device = self.vae.device
    if not isinstance(vae_device, torch.device):
        vae_device = torch.device(vae_device)
    if vae_device.type == 'cpu':
        self.vae.to(self.device)

    latents = latents.to(device, dtype=dtype)
    # NOTE(review): 0.18215 is hard-coded; presumably this should match the
    # scaling applied when encoding — confirm against encode_images.
    latents = latents / 0.18215
    images = self.vae.decode(latents).sample
    images = images.to(device, dtype=dtype)

    return images
|
||||
|
||||
def encode_image_prompt_pairs(
|
||||
self,
|
||||
prompt_list: List[str],
|
||||
|
||||
Reference in New Issue
Block a user