mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Various bug fixes. Finalized targeted guidance algo
This commit is contained in:
@@ -187,117 +187,81 @@ class SDTrainer(BaseSDTrainProcess):
             **kwargs
     ):
         with torch.no_grad():
-            conditional_noisy_latents = noisy_latents
+            # Perform targeted guidance (working title)
+            conditional_noisy_latents = noisy_latents  # target images
             dtype = get_torch_dtype(self.train_config.dtype)
 
             if batch.unconditional_latents is not None:
-                # Encode the unconditional image into latents
+                # unconditional latents are the "neutral" images. Add noise here identical to
+                # the noise added to the conditional latents, at the same timesteps
                 unconditional_noisy_latents = self.sd.noise_scheduler.add_noise(
                     batch.unconditional_latents, noise, timesteps
                 )
 
-                # was_network_active = self.network.is_active
+                # calculate the differential between our conditional (target image) and our unconditional (neutral image)
+                target_differential_noise = unconditional_noisy_latents - conditional_noisy_latents
+                target_differential_noise = target_differential_noise.detach()
+
+                # add the target differential to the target latents as if it were noise with the scheduler, scaled to
+                # the current timestep. Scaling the noise here is important as it scales our guidance to the current
+                # timestep. This is the key to making the guidance work.
+                guidance_latents = self.sd.noise_scheduler.add_noise(
+                    conditional_noisy_latents,
+                    target_differential_noise,
+                    timesteps
+                )
+
+                # Disable the LoRA network so we can predict parent network knowledge without it
                 self.network.is_active = False
                 self.sd.unet.eval()
 
-                # calculate the differential between our conditional (target image) and our unconditional ("bad" image)
-                target_differential = unconditional_noisy_latents - conditional_noisy_latents
-
-                # scale the target differential by the scheduler
-                # todo, scale it the right way
-                # target_differential = self.sd.noise_scheduler.add_noise(
-                #     torch.zeros_like(target_differential),
-                #     target_differential,
-                #     timesteps
-                # )
-
-                target_differential = target_differential.detach()
-
-                # add the target differential to the target latents as if it were noise with the scheduler scaled to
-                # the current timestep. Scaling the noise here is IMPORTANT and will lead to a blurry targeted area
-                # if not done properly
-                # guidance_latents = self.sd.noise_scheduler.add_noise(
-                #     conditional_noisy_latents,
-                #     target_differential,
-                #     timesteps
-                # )
-
-                # guidance_latents = conditional_noisy_latents + target_differential
-                # target_noise = conditional_noisy_latents + target_differential
-
-                # With LoRA network bypassed, predict noise to get a baseline of what the network
-                # wants to do with the latents + noise. Pass our target latents here for the input.
-                target_unconditional = self.sd.predict_noise(
-                    latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
+                # Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
+                # This acts as our control to preserve the unaltered parts of the image.
+                baseline_prediction = self.sd.predict_noise(
+                    latents=guidance_latents.to(self.device_torch, dtype=dtype).detach(),
                     conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
                     timestep=timesteps,
                     guidance_scale=1.0,
                     **pred_kwargs  # adapter residuals in here
                 ).detach()
-                target_conditional = self.sd.predict_noise(
-                    latents=conditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
-                    conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
-                    timestep=timesteps,
-                    guidance_scale=1.0,
-                    **pred_kwargs  # adapter residuals in here
-                ).detach()
-
-                # we calculate the network's current knowledge so we do not overlearn what we already know
-                current_knowledge = target_unconditional - target_conditional
-
-                # we now have the differential noise prediction needed to create our convergence target
-                target_unknown_knowledge = target_differential - current_knowledge
 
         # turn the LoRA network back on.
         self.sd.unet.train()
         self.network.is_active = True
         self.network.multiplier = network_weight_list
 
-        # with LoRA active, predict the noise with the scaled differential latents added. This will allow us
-        # the opportunity to predict the differential + noise that was added to the latents.
-        prediction_unconditional = self.sd.predict_noise(
-            latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
+        # do our prediction with LoRA active on the scaled guidance latents
+        prediction = self.sd.predict_noise(
+            latents=guidance_latents.to(self.device_torch, dtype=dtype).detach(),
             conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
             timestep=timesteps,
             guidance_scale=1.0,
             **pred_kwargs  # adapter residuals in here
         )
 
-        # remove the baseline conditional prediction. This will leave only the divergence from the baseline and
-        # the prediction of the added differential noise
-        # prediction_positive = prediction_unconditional - target_unconditional
-        prediction_positive = target_unconditional - prediction_unconditional
+        # remove the baseline prediction from our prediction to get the differential between the two
+        # all that should be left is the differential between the conditional and unconditional images
+        pred_differential_noise = prediction - baseline_prediction
 
-        # for loss, we target ONLY the unscaled differential between our conditional and unconditional latents
-        # this is the diffusion training process.
+        # for loss, we target ONLY the unscaled differential between our conditional and unconditional latents,
+        # not the timestep scaled noise that was added. This is the diffusion training process.
         # This will guide the network to make identical predictions it previously did for everything EXCEPT our
-        # differential between the conditional and unconditional images
-        positive_loss = torch.nn.functional.mse_loss(
-            prediction_positive.float(),
-            target_unknown_knowledge.float(),
+        # differential between the conditional and unconditional images (target)
+        loss = torch.nn.functional.mse_loss(
+            pred_differential_noise.float(),
+            target_differential_noise.float(),
             reduction="none"
         )
 
-        # add adain loss
-        positive_loss = positive_loss
-        positive_loss = positive_loss.mean([1, 2, 3])
-
-        # positive_loss = positive_loss + adain_loss.mean([1, 2, 3])
-        # send it backwards BEFORE switching network polarity
-        positive_loss = self.apply_snr(positive_loss, timesteps)
-        positive_loss = positive_loss.mean()
-        positive_loss.backward()
-        # loss = positive_loss.detach() + negative_loss.detach()
-        loss = positive_loss.detach()
+        loss = loss.mean([1, 2, 3])
+        loss = self.apply_snr(loss, timesteps)
+        loss = loss.mean()
+        loss.backward()
 
-        # add a grad so other backward does not fail
+        # detach it so parent class can run backward on no grads without throwing error
         loss = loss.detach()
         loss.requires_grad_(True)
 
         # restore network
         self.network.multiplier = network_weight_list
 
         return loss
 
     def get_prior_prediction(
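To make the finalized algorithm easier to follow outside the diff, here is a minimal, self-contained sketch of the targeted-guidance objective. It is illustrative only: the toy add_noise, the bare model call, and the adapter_enabled toggle are stand-ins for self.sd.noise_scheduler.add_noise, self.sd.predict_noise, and self.network.is_active, and are not the repo's API.

# Hedged sketch of the targeted-guidance loss; all names here are assumptions.
import torch

def add_noise(sample, noise, t, num_steps=1000):
    # toy DDPM-style mix; the real code uses the scheduler's add_noise
    alpha = 1.0 - t.float().view(-1, 1, 1, 1) / num_steps
    return alpha.sqrt() * sample + (1 - alpha).sqrt() * noise

def targeted_guidance_loss(model, cond_latents, uncond_latents, noise, t, embeds):
    with torch.no_grad():
        cond_noisy = add_noise(cond_latents, noise, t)
        uncond_noisy = add_noise(uncond_latents, noise, t)
        # the unscaled differential between neutral and target images is the loss target
        target_differential_noise = (uncond_noisy - cond_noisy).detach()
        # re-add it as if it were noise so the guidance is scaled to the timestep
        guidance_latents = add_noise(cond_noisy, target_differential_noise, t)
        model.adapter_enabled = False  # assumed toggle for the LoRA network
        baseline_prediction = model(guidance_latents, t, embeds).detach()
    model.adapter_enabled = True
    prediction = model(guidance_latents, t, embeds)
    # only the divergence from the frozen baseline is trained
    pred_differential_noise = prediction - baseline_prediction
    return torch.nn.functional.mse_loss(
        pred_differential_noise.float(), target_differential_noise.float()
    )

The design choice the comments stress: the target is the unscaled image differential, while the input differential is timestep-scaled, so the network learns the concept at every noise level without blurring untouched regions.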
@@ -1061,6 +1061,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # load last saved weights
             if latest_save_path is not None:
                 self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
 
+            # self.step_num = self.embedding.step
+            # self.start_step = self.step_num
             params.append({
                 'params': self.embedding.get_trainable_params(),
                 'lr': self.train_config.embedding_lr
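The params.append({...}) pattern above builds per-group optimizer settings so the embedding can use its own learning rate. A standalone sketch of the same idea; the parameter names and values are illustrative, only the torch optimizer API is real:

# Per-group learning rates, mirroring the params.append({...}) pattern.
import torch

embedding_vec = torch.nn.Parameter(torch.zeros(1, 768))  # trainable embedding
lora_weights = torch.nn.Parameter(torch.zeros(768, 4))   # other trainable params

params = []
params.append({'params': [embedding_vec], 'lr': 5e-4})   # embedding_lr
params.append({'params': [lora_weights]})                 # falls back to default lr

optimizer = torch.optim.AdamW(params, lr=1e-4)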
@@ -1068,27 +1071,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
             flush()
 
-            if self.embed_config is not None:
-                self.embedding = Embedding(
-                    sd=self.sd,
-                    embed_config=self.embed_config
-                )
-                latest_save_path = self.get_latest_save_path(self.embed_config.trigger)
-                # load last saved weights
-                if latest_save_path is not None:
-                    self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
-
-                # resume state from embedding
-                self.step_num = self.embedding.step
-                self.start_step = self.step_num
-
-                params = self.get_params()
-                if not params:
-                    # set trainable params
-                    params = self.embedding.get_trainable_params()
-
-                flush()
-
             if self.adapter_config is not None:
                 self.setup_adapter()
                 # set trainable params
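The removed lines were a second, duplicate initialization of self.embedding later in the same method; the surviving copy (previous hunk) already loads the latest save. A hedged, standalone sketch of the once-only pattern the fix enforces, with all names illustrative:

# Illustrative: set up the embedding exactly once, loading the latest save if present.
import os

def ensure_embedding(state: dict, save_dir: str, trigger: str) -> dict:
    if 'embedding' in state:  # already initialized earlier in the run
        return state['embedding']
    embedding = {'trigger': trigger, 'step': 0}
    latest = os.path.join(save_dir, f"{trigger}.safetensors")
    if os.path.exists(latest):
        embedding['loaded_from'] = latest  # real code loads weights and step here
    state['embedding'] = embedding
    return embedding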
@@ -327,6 +327,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
 
         with torch.no_grad():
             adapter_images = None
+            self.sd.unet.eval()
 
             # for a complete slider, the batch size is 4 to begin with now
             true_batch_size = prompt_pair.target_class.text_embeds.shape[0] * self.train_config.batch_size
@@ -385,21 +386,22 @@ class TrainSliderProcess(BaseSDTrainProcess):
             latents = noise * self.sd.noise_scheduler.init_noise_sigma
             latents = latents.to(self.device_torch, dtype=dtype)
 
-            with self.network:
-                assert self.network.is_active
-                # pass the multiplier list to the network
-                self.network.multiplier = prompt_pair.multiplier_list
-                denoised_latents = self.sd.diffuse_some_steps(
-                    latents,  # pass simple noise latents
-                    train_tools.concat_prompt_embeddings(
-                        prompt_pair.positive_target,  # unconditional
-                        prompt_pair.target_class,  # target
-                        self.train_config.batch_size,
-                    ),
-                    start_timesteps=0,
-                    total_timesteps=timesteps_to,
-                    guidance_scale=3,
-                )
-            assert not self.network.is_active
+            self.sd.unet.eval()
+            # pass the multiplier list to the network
+            self.network.multiplier = prompt_pair.multiplier_list
+            denoised_latents = self.sd.diffuse_some_steps(
+                latents,  # pass simple noise latents
+                train_tools.concat_prompt_embeddings(
+                    prompt_pair.positive_target,  # unconditional
+                    prompt_pair.target_class,  # target
+                    self.train_config.batch_size,
+                ),
+                start_timesteps=0,
+                total_timesteps=timesteps_to,
+                guidance_scale=3,
+            )
 
             noise_scheduler.set_timesteps(1000)
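diffuse_some_steps runs only the first part of the sampling loop to produce partially denoised latents for the slider pass. A generic diffusers-style sketch of that idea, with classifier-free guidance omitted for brevity; every name below is illustrative, not the repo's API:

# Hedged sketch of partial denoising, assuming a diffusers-style scheduler/UNet.
import torch

@torch.no_grad()
def diffuse_some_steps(unet, scheduler, latents, embeds, total_timesteps):
    scheduler.set_timesteps(1000)
    # run only the first `total_timesteps` steps of the schedule
    for t in scheduler.timesteps[:total_timesteps]:
        model_input = scheduler.scale_model_input(latents, t)
        noise_pred = unet(model_input, t, encoder_hidden_states=embeds).sample
        latents = scheduler.step(noise_pred, t, latents).prev_sample
    return latents  # still noisy, but on the model's data manifold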
@@ -473,6 +475,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
             denoised_latents = denoised_latents.detach()
 
         self.sd.set_device_state(self.train_slider_device_state)
+        self.sd.unet.train()
         # start accumulating gradients
         self.optimizer.zero_grad(set_to_none=True)
 
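The two lines added in this file (unet.eval() before the no-grad generation, unet.train() before gradient accumulation) follow the standard mode-toggle pattern. A minimal sketch, with make_latents and compute_loss as illustrative stand-ins:

# Sketch: eval mode for no-grad generation, train mode before backward.
import torch

def training_step(unet, optimizer, make_latents, compute_loss):
    unet.eval()
    with torch.no_grad():
        denoised = make_latents().detach()  # e.g. partial denoising as above
    unet.train()
    optimizer.zero_grad(set_to_none=True)
    loss = compute_loss(denoised)
    loss.backward()
    optimizer.step()
    return loss.detach()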
@@ -775,9 +775,9 @@ class PoiFileItemDTOMixin:
         with open(caption_path, 'r', encoding='utf-8') as f:
             json_data = json.load(f)
             if 'poi' not in json_data:
-                raise Exception(f"Error: poi not found in caption file: {caption_path}")
+                print(f"Warning: poi not found in caption file: {caption_path}")
             if self.poi not in json_data['poi']:
-                raise Exception(f"Error: poi not found in caption file: {caption_path}")
+                print(f"Warning: poi not found in caption file: {caption_path}")
         # poi has x, y, width, height
+        # do full image if no poi
         self.poi_x = 0
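With the raise replaced by a warning, a missing point of interest now falls back to the full image instead of killing the run. A hedged sketch of that behavior; the JSON layout is inferred from the comments above and the helper name is hypothetical:

# Sketch: load a poi, warn and fall back to the full image when it is missing.
import json

def load_poi(caption_path: str, poi_name: str, img_w: int, img_h: int) -> dict:
    with open(caption_path, 'r', encoding='utf-8') as f:
        json_data = json.load(f)
    poi = json_data.get('poi', {}).get(poi_name)
    if poi is None:
        print(f"Warning: poi not found in caption file: {caption_path}")
        # do full image if no poi
        return {'x': 0, 'y': 0, 'width': img_w, 'height': img_h}
    return poi  # poi has x, y, width, height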
@@ -47,12 +47,15 @@ class Embedding:
         self.placeholder_token_ids = []
         self.embedding_tokens = []
 
-        print(f"Adding {placeholder_tokens} tokens to tokenizer")
+        print(f"Adding {self.embed_config.tokens} tokens to tokenizer")
 
         for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list):
             num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
             if num_added_tokens != self.embed_config.tokens:
                 raise ValueError(
                     f"The tokenizer already contains the token {self.embed_config.trigger}. Please pass a different"
-                    " `placeholder_token` that is not already in the tokenizer."
+                    f" `placeholder_token` that is not already in the tokenizer. Only added {num_added_tokens}"
                 )
 
         # Convert the initializer_token, placeholder_token to ids
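For context, the add_tokens check above relies on the transformers tokenizer returning how many tokens were actually new. A standalone illustration using the real transformers API; the checkpoint and token names are examples, not the repo's config:

# Multi-token trigger registration with a CLIP tokenizer/text encoder.
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

placeholder_tokens = ["<concept>", "<concept>_1", "<concept>_2"]
num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
if num_added_tokens != len(placeholder_tokens):
    raise ValueError(
        f"Tokenizer already contains some of {placeholder_tokens}. "
        f"Only added {num_added_tokens}"
    )

# grow the embedding matrix so the new ids have trainable rows
text_encoder.resize_token_embeddings(len(tokenizer))
placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)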
@@ -115,10 +118,10 @@ class Embedding:
 
     def _set_vec(self, new_vector, text_encoder_idx=0):
         # shape is (1, 768) for SD 1.5 for 1 token
-        token_embeds = self.text_encoder_list[0].get_input_embeddings().weight.data
+        token_embeds = self.text_encoder_list[text_encoder_idx].get_input_embeddings().weight.data
         for i in range(new_vector.shape[0]):
             # apply the weights to the placeholder tokens while preserving gradient
-            token_embeds[self.placeholder_token_ids[0][i]] = new_vector[i].clone()
+            token_embeds[self.placeholder_token_ids[text_encoder_idx][i]] = new_vector[i].clone()
 
     # make setter and getter for vec
     @property
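The bug fixed here matters for multi-encoder models (SDXL has two text encoders): the old code always wrote into encoder 0's embedding table. A toy reproduction of the fix, with shapes and tables purely illustrative:

# Write into the table selected by text_encoder_idx, not the hard-coded index 0.
import torch

token_embed_tables = [torch.zeros(49410, 768), torch.zeros(49410, 768)]
placeholder_token_ids = [[49408, 49409], [49408, 49409]]

def set_vec(new_vector, text_encoder_idx=0):
    token_embeds = token_embed_tables[text_encoder_idx]  # was token_embed_tables[0]
    for i in range(new_vector.shape[0]):
        token_embeds[placeholder_token_ids[text_encoder_idx][i]] = new_vector[i].clone()

set_vec(torch.randn(2, 768), text_encoder_idx=1)  # now updates the second encoder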
@@ -249,30 +252,32 @@ class Embedding:
             else:
                 return
 
-        # textual inversion embeddings
-        if 'string_to_param' in data:
-            param_dict = data['string_to_param']
-            if hasattr(param_dict, '_parameters'):
-                param_dict = getattr(param_dict,
-                                     '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
-            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
-            emb = next(iter(param_dict.items()))[1]
-        # diffuser concepts
-        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
-            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-            emb = next(iter(data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-        else:
-            raise Exception(
-                f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
-
-        if 'step' in data:
-            self.step = int(data['step'])
-
-        self.vec = emb.detach().to(device, dtype=torch.float32)
+        if self.sd.is_xl:
+            self.vec = tensors['clip_l'].detach().to(device, dtype=torch.float32)
+            self.vec2 = tensors['clip_g'].detach().to(device, dtype=torch.float32)
+            if 'step' in data:
+                self.step = int(data['step'])
+        else:
+            # textual inversion embeddings
+            if 'string_to_param' in data:
+                param_dict = data['string_to_param']
+                if hasattr(param_dict, '_parameters'):
+                    param_dict = getattr(param_dict,
+                                         '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
+                assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+                emb = next(iter(param_dict.items()))[1]
+            # diffuser concepts
+            elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
+                assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+                emb = next(iter(data.values()))
+                if len(emb.shape) == 1:
+                    emb = emb.unsqueeze(0)
+            else:
+                raise Exception(
+                    f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
+
+            if 'step' in data:
+                self.step = int(data['step'])
+
+            self.vec = emb.detach().to(device, dtype=torch.float32)
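The new branch handles SDXL embeddings, which carry one vector per text encoder ('clip_l' and 'clip_g'), while older SD embeddings hold a single tensor. A hedged sketch of the branching load; safetensors usage is the real API, the function shape is illustrative:

# Load either a two-vector SDXL embedding or a single-tensor SD embedding.
import torch
from safetensors.torch import load_file

def load_embedding(filename: str, device, is_xl: bool):
    tensors = load_file(filename)
    if is_xl:
        vec = tensors['clip_l'].detach().to(device, dtype=torch.float32)
        vec2 = tensors['clip_g'].detach().to(device, dtype=torch.float32)
        return vec, vec2
    # single-encoder checkpoints store one embedding tensor
    emb = next(iter(tensors.values()))
    if len(emb.shape) == 1:
        emb = emb.unsqueeze(0)
    return emb.detach().to(device, dtype=torch.float32), None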