Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-30 11:11:37 +00:00
Various bug fixes. Finalized targeted guidance algorithm
@@ -187,117 +187,81 @@ class SDTrainer(BaseSDTrainProcess):
             **kwargs
     ):
         with torch.no_grad():
-            conditional_noisy_latents = noisy_latents
+            # Perform targeted guidance (working title)
+            conditional_noisy_latents = noisy_latents  # target images
             dtype = get_torch_dtype(self.train_config.dtype)
 
         if batch.unconditional_latents is not None:
-            # Encode the unconditional image into latents
+            # unconditional latents are the "neutral" images. Add noise here identical to
+            # the noise added to the conditional latents, at the same timesteps
             unconditional_noisy_latents = self.sd.noise_scheduler.add_noise(
                 batch.unconditional_latents, noise, timesteps
             )
 
-            # was_network_active = self.network.is_active
+            # calculate the differential between our conditional (target image) and our unconditional (neutral image)
+            target_differential_noise = unconditional_noisy_latents - conditional_noisy_latents
+            target_differential_noise = target_differential_noise.detach()
 
+            # add the target differential to the target latents as if it were noise with the scheduler, scaled to
+            # the current timestep. Scaling the noise here is important as it scales our guidance to the current
+            # timestep. This is the key to making the guidance work.
+            guidance_latents = self.sd.noise_scheduler.add_noise(
+                conditional_noisy_latents,
+                target_differential_noise,
+                timesteps
+            )
 
+            # Disable the LoRA network so we can predict parent network knowledge without it
             self.network.is_active = False
             self.sd.unet.eval()
 
-            # calculate the differential between our conditional (target image) and out unconditional ("bad" image)
-            target_differential = unconditional_noisy_latents - conditional_noisy_latents
-            # scale the target differential by the scheduler
-            # todo, scale it the right way
-            # target_differential = self.sd.noise_scheduler.add_noise(
-            #     torch.zeros_like(target_differential),
-            #     target_differential,
-            #     timesteps
-            # )
-
-            target_differential = target_differential.detach()
-
-            # add the target differential to the target latents as if it were noise with the scheduler scaled to
-            # the current timestep. Scaling the noise here is IMPORTANT and will lead to a blurry targeted area if not done
-            # properly
-            # guidance_latents = self.sd.noise_scheduler.add_noise(
-            #     conditional_noisy_latents,
-            #     target_differential,
-            #     timesteps
-            # )
-
-            # guidance_latents = conditional_noisy_latents + target_differential
-            # target_noise = conditional_noisy_latents + target_differential
-
-            # With LoRA network bypassed, predict noise to get a baseline of what the network
-            # wants to do with the latents + noise. Pass our target latents here for the input.
-            target_unconditional = self.sd.predict_noise(
-                latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
+            # Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
+            # This acts as our control to preserve the unaltered parts of the image.
+            baseline_prediction = self.sd.predict_noise(
+                latents=guidance_latents.to(self.device_torch, dtype=dtype).detach(),
                 conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
                 timestep=timesteps,
                 guidance_scale=1.0,
                 **pred_kwargs  # adapter residuals in here
             ).detach()
-            target_conditional = self.sd.predict_noise(
-                latents=conditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
-                conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
-                timestep=timesteps,
-                guidance_scale=1.0,
-                **pred_kwargs  # adapter residuals in here
-            ).detach()
-
-            # we calculate the networks current knowledge so we do not overlearn what we know
-            current_knowledge = target_unconditional - target_conditional
-
-            # we now have the differential noise prediction needed to create our convergence target
-            target_unknown_knowledge = target_differential - current_knowledge
 
             # turn the LoRA network back on.
             self.sd.unet.train()
             self.network.is_active = True
             self.network.multiplier = network_weight_list
 
-            # with LoRA active, predict the noise with the scaled differential latents added. This will allow us
-            # the opportunity to predict the differential + noise that was added to the latents.
-            prediction_unconditional = self.sd.predict_noise(
-                latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
+            # do our prediction with LoRA active on the scaled guidance latents
+            prediction = self.sd.predict_noise(
+                latents=guidance_latents.to(self.device_torch, dtype=dtype).detach(),
                 conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
                 timestep=timesteps,
                 guidance_scale=1.0,
                 **pred_kwargs  # adapter residuals in here
             )
 
-            # remove the baseline conditional prediction. This will leave only the divergence from the baseline and
-            # the prediction of the added differential noise
-            # prediction_positive = prediction_unconditional - target_unconditional
-            prediction_positive = target_unconditional - prediction_unconditional
+            # remove the baseline prediction from our prediction to get the differential between the two
+            # all that should be left is the differential between the conditional and unconditional images
+            pred_differential_noise = prediction - baseline_prediction
 
             # for loss, we target ONLY the unscaled differential between our conditional and unconditional latents
-            # this is the diffusion training process.
+            # not the timestep scaled noise that was added. This is the diffusion training process.
             # This will guide the network to make identical predictions it previously did for everything EXCEPT our
-            # differential between the conditional and unconditional images
-            positive_loss = torch.nn.functional.mse_loss(
-                prediction_positive.float(),
-                target_unknown_knowledge.float(),
+            # differential between the conditional and unconditional images (target)
+            loss = torch.nn.functional.mse_loss(
+                pred_differential_noise.float(),
+                target_differential_noise.float(),
                 reduction="none"
             )
 
-            # add adain loss
-            positive_loss = positive_loss
+            loss = loss.mean([1, 2, 3])
+            loss = self.apply_snr(loss, timesteps)
+            loss = loss.mean()
+            loss.backward()
 
-            positive_loss = positive_loss.mean([1, 2, 3])
-            # positive_loss = positive_loss + adain_loss.mean([1, 2, 3])
-            # send it backwards BEFORE switching network polarity
-            positive_loss = self.apply_snr(positive_loss, timesteps)
-            positive_loss = positive_loss.mean()
-            positive_loss.backward()
-            # loss = positive_loss.detach() + negative_loss.detach()
-            loss = positive_loss.detach()
+            # detach it so parent class can run backward on no grads without throwing error
+            loss = loss.detach()
 
-            # add a grad so other backward does not fail
             loss.requires_grad_(True)
 
-            # restore network
-            self.network.multiplier = network_weight_list
 
             return loss
 
     def get_prior_prediction(
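
Pulled out of the diff, the finalized guidance branch reduces to the flow below. This is a minimal sketch, not the repo's API: add_noise, predict_noise_frozen, and predict_noise_lora are stand-in callables for noise_scheduler.add_noise and the two sd.predict_noise calls above (LoRA off, then LoRA on).

import torch

def targeted_guidance_loss(
    conditional_latents,    # latents of the target images
    unconditional_latents,  # latents of the matching "neutral" images
    noise,
    timesteps,
    add_noise,              # stand-in for scheduler.add_noise(samples, noise, timesteps)
    predict_noise_frozen,   # stand-in UNet forward pass with the LoRA disabled
    predict_noise_lora,     # stand-in UNet forward pass with the LoRA active
):
    with torch.no_grad():
        # noise both image sets identically, at the same timesteps
        cond_noisy = add_noise(conditional_latents, noise, timesteps)
        uncond_noisy = add_noise(unconditional_latents, noise, timesteps)

        # the differential is treated as extra "noise" and scheduler-scaled
        # to the current timestep before being added back onto the latents
        target_diff = (uncond_noisy - cond_noisy).detach()
        guidance_latents = add_noise(cond_noisy, target_diff, timesteps)

        # control: what the frozen parent network does with those latents
        baseline = predict_noise_frozen(guidance_latents, timesteps).detach()

    # prediction with the LoRA active, on the same guidance latents
    prediction = predict_noise_lora(guidance_latents, timesteps)

    # train only on the divergence from the baseline, targeting the
    # differential between the two latent sets
    pred_diff = prediction - baseline
    return torch.nn.functional.mse_loss(pred_diff.float(), target_diff.float())

# toy usage with identity "networks" and a fake scheduler
lat_c, lat_u = torch.randn(2, 1, 4, 64, 64)
noise, t = torch.randn(1, 4, 64, 64), torch.tensor([500])
fake_add_noise = lambda x, n, t: x + 0.5 * n
loss = targeted_guidance_loss(lat_c, lat_u, noise, t, fake_add_noise,
                              lambda x, t: x, lambda x, t: x * 1.01)

The real method keeps reduction="none" and applies per-sample SNR weighting (apply_snr) before reducing; the sketch collapses that to a plain mean.
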
@@ -1061,6 +1061,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # load last saved weights
         if latest_save_path is not None:
             self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
+
+            # self.step_num = self.embedding.step
+            # self.start_step = self.step_num
         params.append({
             'params': self.embedding.get_trainable_params(),
             'lr': self.train_config.embedding_lr
@@ -1068,27 +1071,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
         flush()
 
-        if self.embed_config is not None:
-            self.embedding = Embedding(
-                sd=self.sd,
-                embed_config=self.embed_config
-            )
-            latest_save_path = self.get_latest_save_path(self.embed_config.trigger)
-            # load last saved weights
-            if latest_save_path is not None:
-                self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
-
-            # resume state from embedding
-            self.step_num = self.embedding.step
-            self.start_step = self.step_num
-
-            params = self.get_params()
-            if not params:
-                # set trainable params
-                params = self.embedding.get_trainable_params()
-
-            flush()
-
         if self.adapter_config is not None:
             self.setup_adapter()
             # set trainable params
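
This hunk deletes what appears to be a duplicated embedding-setup block; the surviving path (see the previous hunk) already builds the Embedding, reloads the latest save, and leaves step resume commented out. A rough sketch of that consolidated shape, with hypothetical free-function plumbing in place of the process object:

def setup_embedding(sd, embed_config, get_latest_save_path, device):
    # hypothetical consolidation sketch, assuming the repo's Embedding class:
    # build once, then reload the newest checkpoint if one exists
    embedding = Embedding(sd=sd, embed_config=embed_config)
    latest_save_path = get_latest_save_path(embed_config.trigger)
    if latest_save_path is not None:
        # load last saved weights; step resume stays commented out upstream
        embedding.load_embedding_from_file(latest_save_path, device)
    return embedding
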
@@ -327,6 +327,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
 
         with torch.no_grad():
             adapter_images = None
+            self.sd.unet.eval()
 
             # for a complete slider, the batch size is 4 to begin with now
             true_batch_size = prompt_pair.target_class.text_embeds.shape[0] * self.train_config.batch_size
@@ -385,21 +386,22 @@ class TrainSliderProcess(BaseSDTrainProcess):
             latents = noise * self.sd.noise_scheduler.init_noise_sigma
             latents = latents.to(self.device_torch, dtype=dtype)
 
-            with self.network:
-                assert self.network.is_active
+            assert not self.network.is_active
+            self.sd.unet.eval()
             # pass the multiplier list to the network
             self.network.multiplier = prompt_pair.multiplier_list
             denoised_latents = self.sd.diffuse_some_steps(
                 latents,  # pass simple noise latents
                 train_tools.concat_prompt_embeddings(
                     prompt_pair.positive_target,  # unconditional
                     prompt_pair.target_class,  # target
                     self.train_config.batch_size,
                 ),
                 start_timesteps=0,
                 total_timesteps=timesteps_to,
                 guidance_scale=3,
             )
 
 
             noise_scheduler.set_timesteps(1000)
 
@@ -473,6 +475,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
             denoised_latents = denoised_latents.detach()
 
             self.sd.set_device_state(self.train_slider_device_state)
+            self.sd.unet.train()
             # start accumulating gradients
             self.optimizer.zero_grad(set_to_none=True)
 
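
The three slider-training hunks above add explicit eval()/train() toggles around the no-grad sampling stage (and assert the LoRA is inactive before sampling). The pattern in isolation, with a stand-in torch module in place of the UNet; nothing here is ai-toolkit API:

import torch

unet = torch.nn.Conv2d(4, 4, 3, padding=1)  # stand-in for the UNet
optimizer = torch.optim.SGD(unet.parameters(), lr=1e-4)

unet.eval()  # sample with dropout/norm layers in inference behavior
with torch.no_grad():
    denoised = unet(torch.randn(1, 4, 64, 64))  # stands in for diffuse_some_steps(...)
denoised = denoised.detach()

unet.train()  # flip back before accumulating gradients
optimizer.zero_grad(set_to_none=True)
loss = unet(denoised).pow(2).mean()  # placeholder objective
loss.backward()
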
@@ -775,9 +775,9 @@ class PoiFileItemDTOMixin:
             with open(caption_path, 'r', encoding='utf-8') as f:
                 json_data = json.load(f)
                 if 'poi' not in json_data:
-                    raise Exception(f"Error: poi not found in caption file: {caption_path}")
+                    print(f"Warning: poi not found in caption file: {caption_path}")
                 if self.poi not in json_data['poi']:
-                    raise Exception(f"Error: poi not found in caption file: {caption_path}")
+                    print(f"Warning: poi not found in caption file: {caption_path}")
             # poi has, x, y, width, height
             # do full image if no poi
             self.poi_x = 0
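
The behavior change here (warn and fall back to the full image instead of raising) can be read as the stand-alone helper below. The shape of the poi entry is inferred from the surrounding comments, so treat the field names as assumptions:

import json

def read_poi(caption_path, poi_name):
    # Returns (x, y, width, height) for a named point of interest, or None
    # (meaning: use the full image) with a warning, mirroring the
    # raise -> print change in the hunk above.
    with open(caption_path, 'r', encoding='utf-8') as f:
        json_data = json.load(f)
    if 'poi' not in json_data or poi_name not in json_data['poi']:
        print(f"Warning: poi not found in caption file: {caption_path}")
        return None
    p = json_data['poi'][poi_name]  # assumed dict keyed by poi name
    return p['x'], p['y'], p['width'], p['height']
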
@@ -47,12 +47,15 @@ class Embedding:
         self.placeholder_token_ids = []
         self.embedding_tokens = []
 
+        print(f"Adding {placeholder_tokens} tokens to tokenizer")
+        print(f"Adding {self.embed_config.tokens} tokens to tokenizer")
+
         for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list):
             num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
             if num_added_tokens != self.embed_config.tokens:
                 raise ValueError(
                     f"The tokenizer already contains the token {self.embed_config.trigger}. Please pass a different"
-                    " `placeholder_token` that is not already in the tokenizer."
+                    f" `placeholder_token` that is not already in the tokenizer. Only added {num_added_tokens}"
                 )
 
         # Convert the initializer_token, placeholder_token to ids
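
For context on the num_added_tokens check: with a transformers CLIP tokenizer, add_tokens returns how many strings were actually new to the vocabulary, so a shortfall means a collision. A sketch under the assumption that placeholder_tokens is a list of new token strings; the checkpoint name is illustrative:

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
placeholder_tokens = ["<s0>", "<s1>"]  # hypothetical trigger pieces

# add_tokens returns the count of tokens that were actually new
num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
if num_added_tokens != len(placeholder_tokens):
    # a shortfall means some token string already existed, which would
    # desync the placeholder-id bookkeeping kept alongside the tokenizer
    raise ValueError(
        f"Only added {num_added_tokens} of {len(placeholder_tokens)} tokens"
    )
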
@@ -115,10 +118,10 @@ class Embedding:
 
     def _set_vec(self, new_vector, text_encoder_idx=0):
         # shape is (1, 768) for SD 1.5 for 1 token
-        token_embeds = self.text_encoder_list[0].get_input_embeddings().weight.data
+        token_embeds = self.text_encoder_list[text_encoder_idx].get_input_embeddings().weight.data
         for i in range(new_vector.shape[0]):
             # apply the weights to the placeholder tokens while preserving gradient
-            token_embeds[self.placeholder_token_ids[0][i]] = new_vector[i].clone()
+            token_embeds[self.placeholder_token_ids[text_encoder_idx][i]] = new_vector[i].clone()
 
     # make setter and getter for vec
     @property
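
Why the index matters: SDXL carries two text encoders with different embedding widths, so a hardcoded [0] means updates intended for encoder 1 either land in encoder 0's table or fail on the width mismatch. A self-contained sketch of the fixed behavior; the sizes and id lists below are illustrative stand-ins, not the repo's state:

import torch

embeds = [torch.nn.Embedding(49409, 768), torch.nn.Embedding(49409, 1280)]
placeholder_token_ids = [[49408], [49408]]

def set_vec(new_vector, text_encoder_idx=0):
    # index the encoder list and the id list with the same idx;
    # the pre-fix code hardcoded [0] for both lookups
    token_embeds = embeds[text_encoder_idx].weight.data
    for i in range(new_vector.shape[0]):
        token_embeds[placeholder_token_ids[text_encoder_idx][i]] = new_vector[i].clone()

set_vec(torch.randn(1, 1280), text_encoder_idx=1)  # lands in the 1280-wide table
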
@@ -249,30 +252,32 @@ class Embedding:
         else:
             return
 
-        # textual inversion embeddings
-        if 'string_to_param' in data:
-            param_dict = data['string_to_param']
-            if hasattr(param_dict, '_parameters'):
-                param_dict = getattr(param_dict,
-                                     '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
-            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
-            emb = next(iter(param_dict.items()))[1]
-        # diffuser concepts
-        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
-            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-            emb = next(iter(data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-        else:
-            raise Exception(
-                f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
-
-        if 'step' in data:
-            self.step = int(data['step'])
-
         if self.sd.is_xl:
             self.vec = tensors['clip_l'].detach().to(device, dtype=torch.float32)
             self.vec2 = tensors['clip_g'].detach().to(device, dtype=torch.float32)
+            if 'step' in data:
+                self.step = int(data['step'])
         else:
+            # textual inversion embeddings
+            if 'string_to_param' in data:
+                param_dict = data['string_to_param']
+                if hasattr(param_dict, '_parameters'):
+                    param_dict = getattr(param_dict,
+                                         '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
+                assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+                emb = next(iter(param_dict.items()))[1]
+            # diffuser concepts
+            elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
+                assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+                emb = next(iter(data.values()))
+                if len(emb.shape) == 1:
+                    emb = emb.unsqueeze(0)
+            else:
+                raise Exception(
+                    f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
+
+            if 'step' in data:
+                self.step = int(data['step'])
+
             self.vec = emb.detach().to(device, dtype=torch.float32)
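
The restructured loader branches on is_xl first and only probes the legacy formats otherwise. A condensed, stand-alone sketch of that control flow, assuming data is the already-deserialized file contents (the diff's tensors object in the XL branch is treated here as the same dict):

import torch

def parse_embedding(data, filename, is_xl=False):
    # XL embeddings ship as two tensors, one per text encoder
    if is_xl:
        vec = data['clip_l'].detach().float()
        vec2 = data['clip_g'].detach().float()
        return vec, vec2, int(data.get('step', 0))

    # A1111-style textual inversion embeddings
    if 'string_to_param' in data:
        param_dict = data['string_to_param']
        if hasattr(param_dict, '_parameters'):
            param_dict = param_dict._parameters  # torch 1.11-era save files
        assert len(param_dict) == 1, 'embedding file has multiple terms in it'
        emb = next(iter(param_dict.values()))
    # diffusers concept: a dict holding a single tensor
    elif isinstance(data, dict) and isinstance(next(iter(data.values())), torch.Tensor):
        assert len(data) == 1, 'embedding file has multiple terms in it'
        emb = next(iter(data.values()))
        if emb.dim() == 1:
            emb = emb.unsqueeze(0)
    else:
        raise Exception(
            f"Couldn't identify {filename} as a textual inversion embedding "
            "or a diffusers concept.")
    return emb.detach().float(), None, int(data.get('step', 0))
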