mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-10 04:59:52 +00:00
Worked on reference slider script. It is working well currently. Still going to tune it a bit before a writeup though
This commit is contained in:
@@ -20,13 +20,18 @@ def flush():
|
||||
gc.collect()
|
||||
|
||||
|
||||
class DatasetConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.pair_folder: str = kwargs.get('pair_folder', None)
|
||||
self.network_weight: float = kwargs.get('network_weight', 1.0)
|
||||
self.target_class: str = kwargs.get('target_class', '')
|
||||
self.size: int = kwargs.get('size', 512)
|
||||
|
||||
|
||||
class ReferenceSliderConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.slider_pair_folder: str = kwargs.get('slider_pair_folder', None)
|
||||
self.resolutions: List[int] = kwargs.get('resolutions', [512])
|
||||
self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
|
||||
self.target_class: int = kwargs.get('target_class', '')
|
||||
self.additional_losses: List[str] = kwargs.get('additional_losses', [])
|
||||
self.datasets: List[DatasetConfig] = [DatasetConfig(**d) for d in kwargs.get('datasets', [])]
|
||||
|
||||
|
||||
class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
@@ -46,12 +51,13 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
if self.data_loader is None:
|
||||
print(f"Loading datasets")
|
||||
datasets = []
|
||||
for res in self.slider_config.resolutions:
|
||||
print(f" - Dataset: {self.slider_config.slider_pair_folder}")
|
||||
for dataset in self.slider_config.datasets:
|
||||
print(f" - Dataset: {dataset.pair_folder}")
|
||||
config = {
|
||||
'path': self.slider_config.slider_pair_folder,
|
||||
'size': res,
|
||||
'default_prompt': self.slider_config.target_class
|
||||
'path': dataset.pair_folder,
|
||||
'size': dataset.size,
|
||||
'default_prompt': dataset.target_class,
|
||||
'network_weight': dataset.network_weight,
|
||||
}
|
||||
image_dataset = PairedImageDataset(config)
|
||||
datasets.append(image_dataset)
|
||||
@@ -78,7 +84,7 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
do_mirror_loss = 'mirror' in self.slider_config.additional_losses
|
||||
|
||||
with torch.no_grad():
|
||||
imgs, prompts = batch
|
||||
imgs, prompts, base_network_weight = batch
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype)
|
||||
# split batched images in half so left is negative and right is positive
|
||||
@@ -129,7 +135,7 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0)
|
||||
noise = torch.cat([noise_positive, noise_negative], dim=0)
|
||||
timesteps = torch.cat([timesteps, timesteps], dim=0)
|
||||
network_multiplier = [1.0, -1.0]
|
||||
network_multiplier = [base_network_weight * 1.0, base_network_weight * -1.0]
|
||||
|
||||
flush()
|
||||
|
||||
@@ -180,7 +186,8 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
noise_pred_pos, noise_pred_neg = torch.chunk(noise_pred, 2, dim=0)
|
||||
# mirror the negative
|
||||
noise_pred_neg = torch.flip(noise_pred_neg.clone(), dims=[3])
|
||||
loss_mirror = torch.nn.functional.mse_loss(noise_pred_pos.float(), noise_pred_neg.float(), reduction="none")
|
||||
loss_mirror = torch.nn.functional.mse_loss(noise_pred_pos.float(), noise_pred_neg.float(),
|
||||
reduction="none")
|
||||
loss_mirror = loss_mirror.mean([1, 2, 3])
|
||||
loss_mirror = loss_mirror.mean()
|
||||
loss_mirror_float = loss_mirror.item()
|
||||
|
||||
Reference in New Issue
Block a user