Added base for ultimate slider. WIP

This commit is contained in:
Jaret Burkett
2023-08-19 15:35:24 -06:00
parent c6675e2801
commit b77b9acc0b
6 changed files with 568 additions and 36 deletions

View File

@@ -5,6 +5,8 @@ import os
from contextlib import nullcontext
from typing import Optional, Union, List
from torch.utils.data import ConcatDataset, DataLoader
from toolkit.config_modules import ReferenceDatasetConfig
from toolkit.data_loader import PairedImageDataset
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds
@@ -21,29 +23,11 @@ def flush():
gc.collect()
class DatasetConfig:
def __init__(self, **kwargs):
# can pass with a side by side pait or a folder with pos and neg folder
self.pair_folder: str = kwargs.get('pair_folder', None)
self.pos_folder: str = kwargs.get('pos_folder', None)
self.neg_folder: str = kwargs.get('neg_folder', None)
self.network_weight: float = float(kwargs.get('network_weight', 1.0))
self.pos_weight: float = float(kwargs.get('pos_weight', self.network_weight))
self.neg_weight: float = float(kwargs.get('neg_weight', self.network_weight))
# make sure they are all absolute values no negatives
self.pos_weight = abs(self.pos_weight)
self.neg_weight = abs(self.neg_weight)
self.target_class: str = kwargs.get('target_class', '')
self.size: int = kwargs.get('size', 512)
class ReferenceSliderConfig:
def __init__(self, **kwargs):
self.additional_losses: List[str] = kwargs.get('additional_losses', [])
self.weight_jitter: float = kwargs.get('weight_jitter', 0.0)
self.datasets: List[DatasetConfig] = [DatasetConfig(**d) for d in kwargs.get('datasets', [])]
self.datasets: List[ReferenceDatasetConfig] = [ReferenceDatasetConfig(**d) for d in kwargs.get('datasets', [])]
class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
@@ -236,7 +220,6 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
loss.backward()
flush()
# apply gradients
optimizer.step()
lr_scheduler.step()