Worked on reference slider script. It is working well currently. Still going to tune it a bit before a writeup though

This commit is contained in:
Jaret Burkett
2023-08-12 17:59:24 -06:00
parent fd95e7b60c
commit 196b693cf0
4 changed files with 54 additions and 33 deletions

View File

@@ -20,13 +20,18 @@ def flush():
gc.collect()
class DatasetConfig:
def __init__(self, **kwargs):
self.pair_folder: str = kwargs.get('pair_folder', None)
self.network_weight: float = kwargs.get('network_weight', 1.0)
self.target_class: str = kwargs.get('target_class', '')
self.size: int = kwargs.get('size', 512)
class ReferenceSliderConfig:
def __init__(self, **kwargs):
self.slider_pair_folder: str = kwargs.get('slider_pair_folder', None)
self.resolutions: List[int] = kwargs.get('resolutions', [512])
self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
self.target_class: int = kwargs.get('target_class', '')
self.additional_losses: List[str] = kwargs.get('additional_losses', [])
self.datasets: List[DatasetConfig] = [DatasetConfig(**d) for d in kwargs.get('datasets', [])]
class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
@@ -46,12 +51,13 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
if self.data_loader is None:
print(f"Loading datasets")
datasets = []
for res in self.slider_config.resolutions:
print(f" - Dataset: {self.slider_config.slider_pair_folder}")
for dataset in self.slider_config.datasets:
print(f" - Dataset: {dataset.pair_folder}")
config = {
'path': self.slider_config.slider_pair_folder,
'size': res,
'default_prompt': self.slider_config.target_class
'path': dataset.pair_folder,
'size': dataset.size,
'default_prompt': dataset.target_class,
'network_weight': dataset.network_weight,
}
image_dataset = PairedImageDataset(config)
datasets.append(image_dataset)
@@ -78,7 +84,7 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
do_mirror_loss = 'mirror' in self.slider_config.additional_losses
with torch.no_grad():
imgs, prompts = batch
imgs, prompts, base_network_weight = batch
dtype = get_torch_dtype(self.train_config.dtype)
imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype)
# split batched images in half so left is negative and right is positive
@@ -129,7 +135,7 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0)
noise = torch.cat([noise_positive, noise_negative], dim=0)
timesteps = torch.cat([timesteps, timesteps], dim=0)
network_multiplier = [1.0, -1.0]
network_multiplier = [base_network_weight * 1.0, base_network_weight * -1.0]
flush()
@@ -180,7 +186,8 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
noise_pred_pos, noise_pred_neg = torch.chunk(noise_pred, 2, dim=0)
# mirror the negative
noise_pred_neg = torch.flip(noise_pred_neg.clone(), dims=[3])
loss_mirror = torch.nn.functional.mse_loss(noise_pred_pos.float(), noise_pred_neg.float(), reduction="none")
loss_mirror = torch.nn.functional.mse_loss(noise_pred_pos.float(), noise_pred_neg.float(),
reduction="none")
loss_mirror = loss_mirror.mean([1, 2, 3])
loss_mirror = loss_mirror.mean()
loss_mirror_float = loss_mirror.item()