Moved SD batch processing to a shared method and added it for use in slider training. Still testing whether it affects quality compared to sampling.

This commit is contained in:
Jaret Burkett
2023-08-26 08:55:00 -06:00
parent aeaca13d69
commit 3367ab6b2c
6 changed files with 178 additions and 115 deletions

View File

@@ -36,67 +36,8 @@ class SDTrainer(BaseSDTrainProcess):
self.sd.text_encoder.train()
def hook_train_loop(self, batch):
with torch.no_grad():
imgs, prompts, dataset_config = batch
# convert the 0 or 1 for is reg to a bool list
is_reg_list = dataset_config.get('is_reg', [0 for _ in range(imgs.shape[0])])
if isinstance(is_reg_list, torch.Tensor):
is_reg_list = is_reg_list.numpy().tolist()
is_reg_list = [bool(x) for x in is_reg_list]
conditioned_prompts = []
for prompt, is_reg in zip(prompts, is_reg_list):
# make sure the embedding is in the prompts
if self.embedding is not None:
prompt = self.embedding.inject_embedding_to_prompt(
prompt,
expand_token=True,
add_if_not_present=True,
)
# make sure trigger is in the prompts if not a regularization run
if self.trigger_word is not None and not is_reg:
prompt = self.sd.inject_trigger_into_prompt(
prompt,
add_if_not_present=True,
)
conditioned_prompts.append(prompt)
batch_size = imgs.shape[0]
dtype = get_torch_dtype(self.train_config.dtype)
imgs = imgs.to(self.device_torch, dtype=dtype)
latents = self.sd.encode_images(imgs)
noise_scheduler = self.sd.noise_scheduler
optimizer = self.optimizer
lr_scheduler = self.lr_scheduler
self.sd.noise_scheduler.set_timesteps(
self.train_config.max_denoising_steps, device=self.device_torch
)
timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch)
timesteps = timesteps.long()
# get noise
noise = self.sd.get_latent_noise(
pixel_height=imgs.shape[2],
pixel_width=imgs.shape[3],
batch_size=batch_size,
noise_offset=self.train_config.noise_offset
).to(self.device_torch, dtype=dtype)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# remove grads for these
noisy_latents.requires_grad = False
noise.requires_grad = False
flush()
dtype = get_torch_dtype(self.train_config.dtype)
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
self.optimizer.zero_grad()
@@ -135,7 +76,7 @@ class SDTrainer(BaseSDTrainProcess):
if self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
else:
target = noise
@@ -144,7 +85,7 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean()
@@ -153,9 +94,9 @@ class SDTrainer(BaseSDTrainProcess):
flush()
# apply gradients
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()
self.optimizer.step()
self.optimizer.zero_grad()
self.lr_scheduler.step()
if self.embedding is not None:
# Let's make sure we don't update any embedding weights besides the newly added token

View File

@@ -283,6 +283,66 @@ class BaseSDTrainProcess(BaseTrainProcess):
else:
print("load_weights not implemented for non-network models")
def process_general_training_batch(self, batch):
    """Prepare one dataloader batch for standard diffusion training.

    Unpacks the batch, injects the trained embedding and/or trigger word
    into each prompt, encodes the images to VAE latents, samples one
    timestep and one noise tensor per item, and returns the noised
    latents. Runs entirely under ``torch.no_grad()`` — gradients flow
    only through the model forward pass the caller performs afterwards.

    :param batch: tuple ``(imgs, prompts, dataset_config)``. ``imgs`` is
        presumably an NCHW image tensor and ``dataset_config`` a dict-like
        with an optional per-item ``is_reg`` 0/1 flag — TODO confirm
        against the dataloader.
    :return: ``(noisy_latents, noise, timesteps, conditioned_prompts, imgs)``
    """
    with torch.no_grad():
        imgs, prompts, dataset_config = batch
        # convert the 0 or 1 for is reg to a bool list
        # (defaults to "no item is a regularization image")
        is_reg_list = dataset_config.get('is_reg', [0 for _ in range(imgs.shape[0])])
        if isinstance(is_reg_list, torch.Tensor):
            is_reg_list = is_reg_list.numpy().tolist()
        is_reg_list = [bool(x) for x in is_reg_list]
        conditioned_prompts = []
        for prompt, is_reg in zip(prompts, is_reg_list):
            # make sure the embedding is in the prompts
            if self.embedding is not None:
                prompt = self.embedding.inject_embedding_to_prompt(
                    prompt,
                    expand_token=True,
                    add_if_not_present=True,
                )
            # make sure trigger is in the prompts if not a regularization run
            # (regularization items deliberately skip trigger injection)
            if self.trigger_word is not None and not is_reg:
                prompt = self.sd.inject_trigger_into_prompt(
                    prompt,
                    add_if_not_present=True,
                )
            conditioned_prompts.append(prompt)
        batch_size = imgs.shape[0]
        dtype = get_torch_dtype(self.train_config.dtype)
        imgs = imgs.to(self.device_torch, dtype=dtype)
        latents = self.sd.encode_images(imgs)
        # scheduler must be (re)configured before add_noise below
        self.sd.noise_scheduler.set_timesteps(
            self.train_config.max_denoising_steps, device=self.device_torch
        )
        # one independently sampled timestep per batch item
        timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch)
        timesteps = timesteps.long()
        # get noise
        # NOTE(review): assumes imgs is NCHW so shape[2]/shape[3] are
        # pixel height/width — confirm against the dataloader output.
        noise = self.sd.get_latent_noise(
            pixel_height=imgs.shape[2],
            pixel_width=imgs.shape[3],
            batch_size=batch_size,
            noise_offset=self.train_config.noise_offset
        ).to(self.device_torch, dtype=dtype)
        noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)
        # remove grads for these
        noisy_latents.requires_grad = False
        noise.requires_grad = False
        return noisy_latents, noise, timesteps, conditioned_prompts, imgs
def run(self):
# run base process run
BaseTrainProcess.run(self)

View File

@@ -169,11 +169,16 @@ class TrainSliderProcess(BaseSDTrainProcess):
self.prompt_pairs = prompt_pairs
# self.anchor_pairs = anchor_pairs
flush()
if self.data_loader is not None:
# we will have images, prep the vae
self.sd.vae.eval()
self.sd.vae.to(self.device_torch)
# end hook_before_train_loop
def hook_train_loop(self, batch):
dtype = get_torch_dtype(self.train_config.dtype)
# get a random pair
prompt_pair: EncodedPromptPair = self.prompt_pairs[
torch.randint(0, len(self.prompt_pairs), (1,)).item()
@@ -207,55 +212,67 @@ class TrainSliderProcess(BaseSDTrainProcess):
)
with torch.no_grad():
self.sd.noise_scheduler.set_timesteps(
self.train_config.max_denoising_steps, device=self.device_torch
)
self.optimizer.zero_grad()
# get a random number of steps
timesteps_to = torch.randint(
1, self.train_config.max_denoising_steps, (1,)
).item()
# for a complete slider, the batch size is 4 to begin with now
true_batch_size = prompt_pair.target_class.text_embeds.shape[0] * self.train_config.batch_size
from_batch = False
if batch is not None:
# training from a batch of images, not generating them ourselves
from_batch = True
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
# get noise
noise = self.sd.get_latent_noise(
pixel_height=height,
pixel_width=width,
batch_size=true_batch_size,
noise_offset=self.train_config.noise_offset,
).to(self.device_torch, dtype=dtype)
denoised_latent_chunks = [noisy_latents] * self.prompt_chunk_size
denoised_latents = torch.cat(denoised_latent_chunks, dim=0)
current_timestep = timesteps
else:
# get latents
latents = noise * self.sd.noise_scheduler.init_noise_sigma
latents = latents.to(self.device_torch, dtype=dtype)
with self.network:
assert self.network.is_active
# pass the multiplier list to the network
self.network.multiplier = prompt_pair.multiplier_list
denoised_latents = self.sd.diffuse_some_steps(
latents, # pass simple noise latents
train_tools.concat_prompt_embeddings(
prompt_pair.positive_target, # unconditional
prompt_pair.target_class, # target
self.train_config.batch_size,
),
start_timesteps=0,
total_timesteps=timesteps_to,
guidance_scale=3,
self.sd.noise_scheduler.set_timesteps(
self.train_config.max_denoising_steps, device=self.device_torch
)
# split the latents into our prompt pair chunks
denoised_latent_chunks = torch.chunk(denoised_latents, self.prompt_chunk_size, dim=0)
self.optimizer.zero_grad()
noise_scheduler.set_timesteps(1000)
# get a random number of steps
timesteps_to = torch.randint(
1, self.train_config.max_denoising_steps, (1,)
).item()
current_timestep_index = int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
current_timestep = noise_scheduler.timesteps[current_timestep_index]
# get noise
noise = self.sd.get_latent_noise(
pixel_height=height,
pixel_width=width,
batch_size=true_batch_size,
noise_offset=self.train_config.noise_offset,
).to(self.device_torch, dtype=dtype)
# get latents
latents = noise * self.sd.noise_scheduler.init_noise_sigma
latents = latents.to(self.device_torch, dtype=dtype)
with self.network:
assert self.network.is_active
# pass the multiplier list to the network
self.network.multiplier = prompt_pair.multiplier_list
denoised_latents = self.sd.diffuse_some_steps(
latents, # pass simple noise latents
train_tools.concat_prompt_embeddings(
prompt_pair.positive_target, # unconditional
prompt_pair.target_class, # target
self.train_config.batch_size,
),
start_timesteps=0,
total_timesteps=timesteps_to,
guidance_scale=3,
)
noise_scheduler.set_timesteps(1000)
# split the latents into our prompt pair chunks
denoised_latent_chunks = torch.chunk(denoised_latents, self.prompt_chunk_size, dim=0)
denoised_latent_chunks = [x.detach() for x in denoised_latent_chunks]
current_timestep_index = int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
current_timestep = noise_scheduler.timesteps[current_timestep_index]
# flush() # 4.2GB to 3GB on 512x512
@@ -267,6 +284,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
current_timestep,
denoised_latents
)
positive_latents = positive_latents.detach()
positive_latents.requires_grad = False
positive_latents_chunks = torch.chunk(positive_latents, self.prompt_chunk_size, dim=0)
@@ -277,6 +295,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
current_timestep,
denoised_latents
)
neutral_latents = neutral_latents.detach()
neutral_latents.requires_grad = False
neutral_latents_chunks = torch.chunk(neutral_latents, self.prompt_chunk_size, dim=0)
@@ -287,9 +306,12 @@ class TrainSliderProcess(BaseSDTrainProcess):
current_timestep,
denoised_latents
)
unconditional_latents = unconditional_latents.detach()
unconditional_latents.requires_grad = False
unconditional_latents_chunks = torch.chunk(unconditional_latents, self.prompt_chunk_size, dim=0)
denoised_latents = denoised_latents.detach()
flush() # 4.2GB to 3GB on 512x512
# 4.20 GB RAM for 512x512
@@ -402,10 +424,14 @@ class TrainSliderProcess(BaseSDTrainProcess):
loss = loss.mean([1, 2, 3])
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# match batch size
timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])]
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler, self.train_config.min_snr_gamma)
if from_batch:
# match batch size
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
else:
# match batch size
timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])]
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean() * prompt_pair_chunk.weight
@@ -427,7 +453,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
positive_latents,
neutral_latents,
unconditional_latents,
latents
# latents
)
# move back to cpu
prompt_pair.to("cpu")

View File

@@ -77,6 +77,7 @@ class TrainConfig:
self.skip_first_sample = kwargs.get('skip_first_sample', False)
self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
self.weight_jitter = kwargs.get('weight_jitter', 0.0)
self.merge_network_on_save = kwargs.get('merge_network_on_save', False)
class ModelConfig:

View File

@@ -101,14 +101,16 @@ if __name__ == '__main__':
from PIL import Image
import torchvision.transforms as transforms
user_path = os.path.expanduser('~')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
input_path = os.path.join(user_path, "Pictures/sample_2_512.png")
output_path = os.path.join(user_path, "Pictures/sample_2_512_llvae.png")
img = Image.open(input_path)
img_tensor = transforms.ToTensor()(img)
img_tensor = img_tensor.unsqueeze(0)
img_tensor = img_tensor.unsqueeze(0).to(device=device, dtype=dtype)
print("input_shape: ", list(img_tensor.shape))
vae = LosslessLatentVAE(in_channels=3, latent_depth=8)
vae = LosslessLatentVAE(in_channels=3, latent_depth=8, dtype=dtype).to(device=device, dtype=dtype)
latent = vae.encode(img_tensor)
print("latent_shape: ", list(latent.shape))
out_tensor = vae.decode(latent)

View File

@@ -152,6 +152,11 @@ class StableDiffusion:
steps_offset=1
)
# move the betas, alphas, and alphas_cumprod to device. Sometimes they get stuck on cpu, not sure why
scheduler.betas = scheduler.betas.to(self.device_torch)
scheduler.alphas = scheduler.alphas.to(self.device_torch)
scheduler.alphas_cumprod = scheduler.alphas_cumprod.to(self.device_torch)
model_path = self.model_config.name_or_path
if 'civitai.com' in self.model_config.name_or_path:
# load is a civit ai model, use the loader.
@@ -323,6 +328,13 @@ class StableDiffusion:
# todo do we disable text encoder here as well if disabled for model, or only do that for training?
if self.is_xl:
# fix guidance rescale for sdxl
# was trained on 0.7 (I believe)
grs = gen_config.guidance_rescale
if grs is None or grs < 0.00001:
grs = 0.7
img = pipeline(
prompt=gen_config.prompt,
prompt_2=gen_config.prompt_2,
@@ -332,7 +344,7 @@ class StableDiffusion:
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
guidance_rescale=gen_config.guidance_rescale,
guidance_rescale=grs,
).images[0]
else:
img = pipeline(
@@ -619,6 +631,27 @@ class StableDiffusion:
return latents
def decode_latents(
        self,
        latents: torch.Tensor,
        device=None,
        dtype=None
):
    """Decode VAE latents back to image tensors.

    :param latents: latent tensor as produced by the encode path
        (i.e. already multiplied by the SD scaling factor 0.18215).
    :param device: target device; defaults to ``self.device``.
    :param dtype: target dtype; defaults to ``self.torch_dtype``.
    :return: decoded image tensor on ``device`` with ``dtype``.
    """
    if device is None:
        device = self.device
    if dtype is None:
        dtype = self.torch_dtype
    # Move the vae to the compute device if it was offloaded to cpu.
    # Compare via str(): `torch.device == 'cpu'` (device vs str) is
    # unreliable across torch versions and could be False for an actual
    # cpu device, silently skipping the move.
    if str(self.vae.device) == 'cpu':
        self.vae.to(self.device)
    latents = latents.to(device, dtype=dtype)
    # undo the SD1.x latent scaling factor applied at encode time
    latents = latents / 0.18215
    images = self.vae.decode(latents).sample
    images = images.to(device, dtype=dtype)
    return images
def encode_image_prompt_pairs(
self,
prompt_list: List[str],