Added "this not that" (tnt) guidance. Added ability to replace strings in prompts.

This commit is contained in:
Jaret Burkett
2024-02-28 20:10:14 -07:00
parent 561914d8e6
commit 337945de9a
7 changed files with 114 additions and 5 deletions

View File

@@ -229,9 +229,13 @@ class SDTrainer(BaseSDTrainProcess):
prior_mask_multiplier = None
target_mask_multiplier = None
dtype = get_torch_dtype(self.train_config.dtype)
has_mask = batch.mask_tensor is not None
with torch.no_grad():
loss_multiplier = torch.tensor(batch.loss_multiplier_list).to(self.device_torch, dtype=torch.float32)
if self.train_config.match_noise_norm:
# match the norm of the noise
noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True)
@@ -364,6 +368,8 @@ class SDTrainer(BaseSDTrainProcess):
# loss = loss + prior_loss
# loss = loss + prior_loss
loss = loss.mean([1, 2, 3])
# apply loss multiplier before prior loss
loss = loss * loss_multiplier
if prior_loss is not None:
loss = loss + prior_loss

View File

@@ -1480,10 +1480,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
else:
batch = None
# if we are doing a reg step, always accumulate
if is_reg_step:
self.is_grad_accumulation_step = True
# setup accumulation
if self.train_config.gradient_accumulation_steps == -1:
# epoch is handling the accumulation, dont touch it

View File

@@ -496,6 +496,8 @@ class DatasetConfig:
self.clip_image_path: str = kwargs.get('clip_image_path', None) # depth maps, etc
self.clip_image_augmentations: List[dict] = kwargs.get('clip_image_augmentations', None)
self.clip_image_shuffle_augmentations: bool = kwargs.get('clip_image_shuffle_augmentations', False)
self.replacements: List[str] = kwargs.get('replacements', [])
self.loss_multiplier: float = kwargs.get('loss_multiplier', 1.0)
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:

View File

@@ -68,7 +68,7 @@ class FileItemDTO(
self.flip_x: bool = kwargs.get('flip_x', False)
self.flip_y: bool = kwargs.get('flip_x', False)
self.augments: List[str] = self.dataset_config.augments
self.loss_multiplier: float = self.dataset_config.loss_multiplier
self.network_weight: float = self.dataset_config.network_weight
self.is_reg = self.dataset_config.is_reg
@@ -124,6 +124,8 @@ class DataLoaderBatchDTO:
control_tensors.append(x.control_tensor)
self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
self.loss_multiplier_list: List[float] = [x.loss_multiplier for x in self.file_items]
if any([x.clip_image_tensor is not None for x in self.file_items]):
# find one to use as a base
base_clip_image_tensor = None

View File

@@ -137,6 +137,13 @@ class CaptionMixin:
prompt = self.default_prompt
if hasattr(self, 'default_caption'):
prompt = self.default_caption
# handle replacements
replacement_list = self.dataset_config.replacements if isinstance(self.dataset_config.replacements, list) else []
for replacement in replacement_list:
from_string, to_string = replacement.split('|')
prompt = prompt.replace(from_string, to_string)
return prompt

View File

@@ -482,6 +482,87 @@ def get_guided_loss_polarity(
return loss
def get_guided_tnt(
        noisy_latents: torch.Tensor,
        conditional_embeds: PromptEmbeds,
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        sd: 'StableDiffusion',
        **kwargs
):
    """Compute and back-propagate the "this not that" (tnt) guidance loss.

    Both ``batch.latents`` ("this") and ``batch.unconditional_latents``
    ("that") are noised with the same ``noise`` and ``timesteps`` and run
    through ``sd.predict_noise`` in one doubled batch, using
    ``conditional_embeds`` for both halves. The loss is the noise-prediction
    MSE on "this" minus the MSE on "that".

    ``noisy_latents``, ``match_adapter_assist`` and ``**kwargs`` are accepted
    so the signature matches the other guidance loss functions; they are not
    read here.

    Side effects: puts ``sd.unet`` in train mode, activates ``sd.network``
    (when present) with the doubled weight list, and calls ``backward()``.

    Returns:
        The per-sample loss (one value per batch item), detached from the
        graph and flagged ``requires_grad`` so a caller may invoke
        ``backward()`` on it again without error.
    """
    dtype = get_torch_dtype(sd.torch_dtype)
    device = sd.device_torch

    with torch.no_grad():
        noise = noise.to(device, dtype=dtype).detach()

        conditional_latents = batch.latents.to(device, dtype=dtype).detach()
        unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

        conditional_noisy_latents = sd.add_noise(
            conditional_latents,
            noise,
            timesteps
        ).detach()

        unconditional_noisy_latents = sd.add_noise(
            unconditional_latents,
            noise,
            timesteps
        ).detach()

        # double up everything to run it through all at once
        cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
        cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
        cat_timesteps = torch.cat([timesteps, timesteps], dim=0)

    # turn the LoRA network back on.
    sd.unet.train()
    if sd.network is not None:
        # the batch was doubled, so the per-item network weights are too
        cat_network_weight_list = list(network_weight_list * 2)
        sd.network.multiplier = cat_network_weight_list
        sd.network.is_active = True

    prediction = sd.predict_noise(
        latents=cat_latents.to(device, dtype=dtype).detach(),
        conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(),
        timestep=cat_timesteps,
        guidance_scale=1.0,
        **pred_kwargs  # adapter residuals in here
    )

    this_prediction, that_prediction = torch.chunk(prediction, 2, dim=0)

    # pull the prediction toward the noise for "this" ...
    this_loss = torch.nn.functional.mse_loss(
        this_prediction.float(),
        noise.float(),
        reduction="none"
    )

    # ... and push it away for "that" (negated loss)
    that_loss = torch.nn.functional.mse_loss(
        that_prediction.float(),
        noise.float(),
        reduction="none"
    ) * -1.0

    loss = this_loss + that_loss

    # reduce everything but the batch dimension -> one value per sample
    loss = loss.mean([1, 2, 3])

    # backward() can only create an implicit gradient for a single-element
    # tensor, so reduce over the batch first; calling it on the per-sample
    # vector raises for any batch size greater than 1.
    loss.mean().backward()

    # detach it so parent class can run backward on no grads without throwing error
    loss = loss.detach()
    loss.requires_grad_(True)

    return loss
# this processes all guidance losses based on the batch information
def get_guidance_loss(
noisy_latents: torch.Tensor,
@@ -529,6 +610,20 @@ def get_guidance_loss(
sd,
**kwargs
)
elif guidance_type == "tnt":
assert unconditional_embeds is None, "Unconditional embeds are not supported for polarity guidance"
return get_guided_loss_polarity(
noisy_latents,
conditional_embeds,
match_adapter_assist,
network_weight_list,
timesteps,
pred_kwargs,
batch,
noise,
sd,
**kwargs
)
elif guidance_type == "targeted_polarity":
assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted polarity guidance"

View File

@@ -1273,6 +1273,7 @@ class StableDiffusion:
images = torch.stack(image_list)
latents = self.vae.encode(images).latent_dist.sample()
# latents = self.vae.encode(images, return_dict=False)[0]
latents = latents * self.vae.config['scaling_factor']
latents = latents.to(device, dtype=dtype)