mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-06 11:10:10 +00:00
Added this not that guidance. Added ability to replace prompts.
This commit is contained in:
@@ -229,9 +229,13 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
prior_mask_multiplier = None
|
||||
target_mask_multiplier = None
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
|
||||
has_mask = batch.mask_tensor is not None
|
||||
|
||||
with torch.no_grad():
|
||||
loss_multiplier = torch.tensor(batch.loss_multiplier_list).to(self.device_torch, dtype=torch.float32)
|
||||
|
||||
if self.train_config.match_noise_norm:
|
||||
# match the norm of the noise
|
||||
noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True)
|
||||
@@ -364,6 +368,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# loss = loss + prior_loss
|
||||
# loss = loss + prior_loss
|
||||
loss = loss.mean([1, 2, 3])
|
||||
# apply loss multiplier before prior loss
|
||||
loss = loss * loss_multiplier
|
||||
if prior_loss is not None:
|
||||
loss = loss + prior_loss
|
||||
|
||||
|
||||
@@ -1480,10 +1480,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
else:
|
||||
batch = None
|
||||
|
||||
# if we are doing a reg step, always accumulate
|
||||
if is_reg_step:
|
||||
self.is_grad_accumulation_step = True
|
||||
|
||||
# setup accumulation
|
||||
if self.train_config.gradient_accumulation_steps == -1:
|
||||
# epoch is handling the accumulation, dont touch it
|
||||
|
||||
@@ -496,6 +496,8 @@ class DatasetConfig:
|
||||
self.clip_image_path: str = kwargs.get('clip_image_path', None) # depth maps, etc
|
||||
self.clip_image_augmentations: List[dict] = kwargs.get('clip_image_augmentations', None)
|
||||
self.clip_image_shuffle_augmentations: bool = kwargs.get('clip_image_shuffle_augmentations', False)
|
||||
self.replacements: List[str] = kwargs.get('replacements', [])
|
||||
self.loss_multiplier: float = kwargs.get('loss_multiplier', 1.0)
|
||||
|
||||
|
||||
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
|
||||
|
||||
@@ -68,7 +68,7 @@ class FileItemDTO(
|
||||
self.flip_x: bool = kwargs.get('flip_x', False)
|
||||
self.flip_y: bool = kwargs.get('flip_x', False)
|
||||
self.augments: List[str] = self.dataset_config.augments
|
||||
|
||||
self.loss_multiplier: float = self.dataset_config.loss_multiplier
|
||||
|
||||
self.network_weight: float = self.dataset_config.network_weight
|
||||
self.is_reg = self.dataset_config.is_reg
|
||||
@@ -124,6 +124,8 @@ class DataLoaderBatchDTO:
|
||||
control_tensors.append(x.control_tensor)
|
||||
self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
|
||||
|
||||
self.loss_multiplier_list: List[float] = [x.loss_multiplier for x in self.file_items]
|
||||
|
||||
if any([x.clip_image_tensor is not None for x in self.file_items]):
|
||||
# find one to use as a base
|
||||
base_clip_image_tensor = None
|
||||
|
||||
@@ -137,6 +137,13 @@ class CaptionMixin:
|
||||
prompt = self.default_prompt
|
||||
if hasattr(self, 'default_caption'):
|
||||
prompt = self.default_caption
|
||||
|
||||
# handle replacements
|
||||
replacement_list = self.dataset_config.replacements if isinstance(self.dataset_config.replacements, list) else []
|
||||
for replacement in replacement_list:
|
||||
from_string, to_string = replacement.split('|')
|
||||
prompt = prompt.replace(from_string, to_string)
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
|
||||
@@ -482,6 +482,87 @@ def get_guided_loss_polarity(
|
||||
return loss
|
||||
|
||||
|
||||
def get_guided_tnt(
|
||||
noisy_latents: torch.Tensor,
|
||||
conditional_embeds: PromptEmbeds,
|
||||
match_adapter_assist: bool,
|
||||
network_weight_list: list,
|
||||
timesteps: torch.Tensor,
|
||||
pred_kwargs: dict,
|
||||
batch: 'DataLoaderBatchDTO',
|
||||
noise: torch.Tensor,
|
||||
sd: 'StableDiffusion',
|
||||
**kwargs
|
||||
):
|
||||
dtype = get_torch_dtype(sd.torch_dtype)
|
||||
device = sd.device_torch
|
||||
with torch.no_grad():
|
||||
dtype = get_torch_dtype(dtype)
|
||||
noise = noise.to(device, dtype=dtype).detach()
|
||||
|
||||
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
|
||||
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
|
||||
|
||||
conditional_noisy_latents = sd.add_noise(
|
||||
conditional_latents,
|
||||
noise,
|
||||
timesteps
|
||||
).detach()
|
||||
|
||||
unconditional_noisy_latents = sd.add_noise(
|
||||
unconditional_latents,
|
||||
noise,
|
||||
timesteps
|
||||
).detach()
|
||||
|
||||
# double up everything to run it through all at once
|
||||
cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
|
||||
cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
|
||||
cat_timesteps = torch.cat([timesteps, timesteps], dim=0)
|
||||
|
||||
|
||||
# turn the LoRA network back on.
|
||||
sd.unet.train()
|
||||
if sd.network is not None:
|
||||
cat_network_weight_list = [weight for weight in network_weight_list * 2]
|
||||
sd.network.multiplier = cat_network_weight_list
|
||||
sd.network.is_active = True
|
||||
|
||||
prediction = sd.predict_noise(
|
||||
latents=cat_latents.to(device, dtype=dtype).detach(),
|
||||
conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(),
|
||||
timestep=cat_timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
)
|
||||
this_prediction, that_prediction = torch.chunk(prediction, 2, dim=0)
|
||||
|
||||
this_loss = torch.nn.functional.mse_loss(
|
||||
this_prediction.float(),
|
||||
noise.float(),
|
||||
reduction="none"
|
||||
)
|
||||
|
||||
that_loss = torch.nn.functional.mse_loss(
|
||||
that_prediction.float(),
|
||||
noise.float(),
|
||||
reduction="none"
|
||||
) * -1.0
|
||||
|
||||
loss = this_loss + that_loss
|
||||
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss.backward()
|
||||
|
||||
# detach it so parent class can run backward on no grads without throwing error
|
||||
loss = loss.detach()
|
||||
loss.requires_grad_(True)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
|
||||
# this processes all guidance losses based on the batch information
|
||||
def get_guidance_loss(
|
||||
noisy_latents: torch.Tensor,
|
||||
@@ -529,6 +610,20 @@ def get_guidance_loss(
|
||||
sd,
|
||||
**kwargs
|
||||
)
|
||||
elif guidance_type == "tnt":
|
||||
assert unconditional_embeds is None, "Unconditional embeds are not supported for polarity guidance"
|
||||
return get_guided_loss_polarity(
|
||||
noisy_latents,
|
||||
conditional_embeds,
|
||||
match_adapter_assist,
|
||||
network_weight_list,
|
||||
timesteps,
|
||||
pred_kwargs,
|
||||
batch,
|
||||
noise,
|
||||
sd,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
elif guidance_type == "targeted_polarity":
|
||||
assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted polarity guidance"
|
||||
|
||||
@@ -1273,6 +1273,7 @@ class StableDiffusion:
|
||||
|
||||
images = torch.stack(image_list)
|
||||
latents = self.vae.encode(images).latent_dist.sample()
|
||||
# latents = self.vae.encode(images, return_dict=False)[0]
|
||||
latents = latents * self.vae.config['scaling_factor']
|
||||
latents = latents.to(device, dtype=dtype)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user