Varous bug fixes. Finalized targeted guidance algo

This commit is contained in:
Jaret Burkett
2023-11-10 12:18:08 -07:00
parent fa6d91ba76
commit 7782caa468
5 changed files with 92 additions and 138 deletions

View File

@@ -187,117 +187,81 @@ class SDTrainer(BaseSDTrainProcess):
**kwargs
):
with torch.no_grad():
conditional_noisy_latents = noisy_latents
# Perform targeted guidance (working title)
conditional_noisy_latents = noisy_latents # target images
dtype = get_torch_dtype(self.train_config.dtype)
if batch.unconditional_latents is not None:
# Encode the unconditional image into latents
# unconditional latents are the "neutral" images. Add noise here identical to
# the noise added to the conditional latents, at the same timesteps
unconditional_noisy_latents = self.sd.noise_scheduler.add_noise(
batch.unconditional_latents, noise, timesteps
)
# was_network_active = self.network.is_active
# calculate the differential between our conditional (target image) and out unconditional (neutral image)
target_differential_noise = unconditional_noisy_latents - conditional_noisy_latents
target_differential_noise = target_differential_noise.detach()
# add the target differential to the target latents as if it were noise with the scheduler, scaled to
# the current timestep. Scaling the noise here is important as it scales our guidance to the current
# timestep. This is the key to making the guidance work.
guidance_latents = self.sd.noise_scheduler.add_noise(
conditional_noisy_latents,
target_differential_noise,
timesteps
)
# Disable the LoRA network so we can predict parent network knowledge without it
self.network.is_active = False
self.sd.unet.eval()
# calculate the differential between our conditional (target image) and out unconditional ("bad" image)
target_differential = unconditional_noisy_latents - conditional_noisy_latents
# scale the target differential by the scheduler
# todo, scale it the right way
# target_differential = self.sd.noise_scheduler.add_noise(
# torch.zeros_like(target_differential),
# target_differential,
# timesteps
# )
target_differential = target_differential.detach()
# add the target differential to the target latents as if it were noise with the scheduler scaled to
# the current timestep. Scaling the noise here is IMPORTANT and will lead to a blurry targeted area if not done
# properly
# guidance_latents = self.sd.noise_scheduler.add_noise(
# conditional_noisy_latents,
# target_differential,
# timesteps
# )
# guidance_latents = conditional_noisy_latents + target_differential
# target_noise = conditional_noisy_latents + target_differential
# With LoRA network bypassed, predict noise to get a baseline of what the network
# wants to do with the latents + noise. Pass our target latents here for the input.
target_unconditional = self.sd.predict_noise(
latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
# Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
# This acts as our control to preserve the unaltered parts of the image.
baseline_prediction = self.sd.predict_noise(
latents=guidance_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
).detach()
target_conditional = self.sd.predict_noise(
latents=conditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
).detach()
# we calculate the networks current knowledge so we do not overlearn what we know
current_knowledge = target_unconditional - target_conditional
# we now have the differential noise prediction needed to create our convergence target
target_unknown_knowledge = target_differential - current_knowledge
# turn the LoRA network back on.
self.sd.unet.train()
self.network.is_active = True
self.network.multiplier = network_weight_list
# with LoRA active, predict the noise with the scaled differential latents added. This will allow us
# the opportunity to predict the differential + noise that was added to the latents.
prediction_unconditional = self.sd.predict_noise(
latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
# do our prediction with LoRA active on the scaled guidance latents
prediction = self.sd.predict_noise(
latents=guidance_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
)
# remove the baseline conditional prediction. This will leave only the divergence from the baseline and
# the prediction of the added differential noise
# prediction_positive = prediction_unconditional - target_unconditional
prediction_positive = target_unconditional - prediction_unconditional
# remove the baseline prediction from our prediction to get the differential between the two
# all that should be left is the differential between the conditional and unconditional images
pred_differential_noise = prediction - baseline_prediction
# for loss, we target ONLY the unscaled differential between our conditional and unconditional latents
# this is the diffusion training process.
# not the timestep scaled noise that was added. This is the diffusion training process.
# This will guide the network to make identical predictions it previously did for everything EXCEPT our
# differential between the conditional and unconditional images
positive_loss = torch.nn.functional.mse_loss(
prediction_positive.float(),
target_unknown_knowledge.float(),
# differential between the conditional and unconditional images (target)
loss = torch.nn.functional.mse_loss(
pred_differential_noise.float(),
target_differential_noise.float(),
reduction="none"
)
# add adain loss
positive_loss = positive_loss
loss = loss.mean([1, 2, 3])
loss = self.apply_snr(loss, timesteps)
loss = loss.mean()
loss.backward()
positive_loss = positive_loss.mean([1, 2, 3])
# positive_loss = positive_loss + adain_loss.mean([1, 2, 3])
# send it backwards BEFORE switching network polarity
positive_loss = self.apply_snr(positive_loss, timesteps)
positive_loss = positive_loss.mean()
positive_loss.backward()
# loss = positive_loss.detach() + negative_loss.detach()
loss = positive_loss.detach()
# add a grad so other backward does not fail
# detach it so parent class can run backward on no grads without throwing error
loss = loss.detach()
loss.requires_grad_(True)
# restore network
self.network.multiplier = network_weight_list
return loss
def get_prior_prediction(

View File

@@ -1061,6 +1061,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
# load last saved weights
if latest_save_path is not None:
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
# self.step_num = self.embedding.step
# self.start_step = self.step_num
params.append({
'params': self.embedding.get_trainable_params(),
'lr': self.train_config.embedding_lr
@@ -1068,27 +1071,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
flush()
if self.embed_config is not None:
self.embedding = Embedding(
sd=self.sd,
embed_config=self.embed_config
)
latest_save_path = self.get_latest_save_path(self.embed_config.trigger)
# load last saved weights
if latest_save_path is not None:
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
# resume state from embedding
self.step_num = self.embedding.step
self.start_step = self.step_num
params = self.get_params()
if not params:
# set trainable params
params = self.embedding.get_trainable_params()
flush()
if self.adapter_config is not None:
self.setup_adapter()
# set trainable params

View File

@@ -327,6 +327,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
with torch.no_grad():
adapter_images = None
self.sd.unet.eval()
# for a complete slider, the batch size is 4 to begin with now
true_batch_size = prompt_pair.target_class.text_embeds.shape[0] * self.train_config.batch_size
@@ -385,21 +386,22 @@ class TrainSliderProcess(BaseSDTrainProcess):
latents = noise * self.sd.noise_scheduler.init_noise_sigma
latents = latents.to(self.device_torch, dtype=dtype)
with self.network:
assert self.network.is_active
# pass the multiplier list to the network
self.network.multiplier = prompt_pair.multiplier_list
denoised_latents = self.sd.diffuse_some_steps(
latents, # pass simple noise latents
train_tools.concat_prompt_embeddings(
prompt_pair.positive_target, # unconditional
prompt_pair.target_class, # target
self.train_config.batch_size,
),
start_timesteps=0,
total_timesteps=timesteps_to,
guidance_scale=3,
)
assert not self.network.is_active
self.sd.unet.eval()
# pass the multiplier list to the network
self.network.multiplier = prompt_pair.multiplier_list
denoised_latents = self.sd.diffuse_some_steps(
latents, # pass simple noise latents
train_tools.concat_prompt_embeddings(
prompt_pair.positive_target, # unconditional
prompt_pair.target_class, # target
self.train_config.batch_size,
),
start_timesteps=0,
total_timesteps=timesteps_to,
guidance_scale=3,
)
noise_scheduler.set_timesteps(1000)
@@ -473,6 +475,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
denoised_latents = denoised_latents.detach()
self.sd.set_device_state(self.train_slider_device_state)
self.sd.unet.train()
# start accumulating gradients
self.optimizer.zero_grad(set_to_none=True)

View File

@@ -775,9 +775,9 @@ class PoiFileItemDTOMixin:
with open(caption_path, 'r', encoding='utf-8') as f:
json_data = json.load(f)
if 'poi' not in json_data:
raise Exception(f"Error: poi not found in caption file: {caption_path}")
print(f"Warning: poi not found in caption file: {caption_path}")
if self.poi not in json_data['poi']:
raise Exception(f"Error: poi not found in caption file: {caption_path}")
print(f"Warning: poi not found in caption file: {caption_path}")
# poi has, x, y, width, height
# do full image if no poi
self.poi_x = 0

View File

@@ -47,12 +47,15 @@ class Embedding:
self.placeholder_token_ids = []
self.embedding_tokens = []
print(f"Adding {placeholder_tokens} tokens to tokenizer")
print(f"Adding {self.embed_config.tokens} tokens to tokenizer")
for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list):
num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
if num_added_tokens != self.embed_config.tokens:
raise ValueError(
f"The tokenizer already contains the token {self.embed_config.trigger}. Please pass a different"
" `placeholder_token` that is not already in the tokenizer."
f" `placeholder_token` that is not already in the tokenizer. Only added {num_added_tokens}"
)
# Convert the initializer_token, placeholder_token to ids
@@ -115,10 +118,10 @@ class Embedding:
def _set_vec(self, new_vector, text_encoder_idx=0):
# shape is (1, 768) for SD 1.5 for 1 token
token_embeds = self.text_encoder_list[0].get_input_embeddings().weight.data
token_embeds = self.text_encoder_list[text_encoder_idx].get_input_embeddings().weight.data
for i in range(new_vector.shape[0]):
# apply the weights to the placeholder tokens while preserving gradient
token_embeds[self.placeholder_token_ids[0][i]] = new_vector[i].clone()
token_embeds[self.placeholder_token_ids[text_encoder_idx][i]] = new_vector[i].clone()
# make setter and getter for vec
@property
@@ -249,30 +252,32 @@ class Embedding:
else:
return
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
if hasattr(param_dict, '_parameters'):
param_dict = getattr(param_dict,
'_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
# diffuser concepts
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
else:
raise Exception(
f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
if 'step' in data:
self.step = int(data['step'])
if self.sd.is_xl:
self.vec = tensors['clip_l'].detach().to(device, dtype=torch.float32)
self.vec2 = tensors['clip_g'].detach().to(device, dtype=torch.float32)
if 'step' in data:
self.step = int(data['step'])
else:
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
if hasattr(param_dict, '_parameters'):
param_dict = getattr(param_dict,
'_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
# diffuser concepts
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
else:
raise Exception(
f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
if 'step' in data:
self.step = int(data['step'])
self.vec = emb.detach().to(device, dtype=torch.float32)