mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Added base for ultimate slider. WIP
This commit is contained in:
@@ -5,6 +5,8 @@ import os
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional, Union, List
|
||||
from torch.utils.data import ConcatDataset, DataLoader
|
||||
|
||||
from toolkit.config_modules import ReferenceDatasetConfig
|
||||
from toolkit.data_loader import PairedImageDataset
|
||||
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
|
||||
from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds
|
||||
@@ -21,29 +23,11 @@ def flush():
|
||||
gc.collect()
|
||||
|
||||
|
||||
class DatasetConfig:
|
||||
def __init__(self, **kwargs):
|
||||
# can pass with a side by side pait or a folder with pos and neg folder
|
||||
self.pair_folder: str = kwargs.get('pair_folder', None)
|
||||
self.pos_folder: str = kwargs.get('pos_folder', None)
|
||||
self.neg_folder: str = kwargs.get('neg_folder', None)
|
||||
|
||||
self.network_weight: float = float(kwargs.get('network_weight', 1.0))
|
||||
self.pos_weight: float = float(kwargs.get('pos_weight', self.network_weight))
|
||||
self.neg_weight: float = float(kwargs.get('neg_weight', self.network_weight))
|
||||
# make sure they are all absolute values no negatives
|
||||
self.pos_weight = abs(self.pos_weight)
|
||||
self.neg_weight = abs(self.neg_weight)
|
||||
|
||||
self.target_class: str = kwargs.get('target_class', '')
|
||||
self.size: int = kwargs.get('size', 512)
|
||||
|
||||
|
||||
class ReferenceSliderConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.additional_losses: List[str] = kwargs.get('additional_losses', [])
|
||||
self.weight_jitter: float = kwargs.get('weight_jitter', 0.0)
|
||||
self.datasets: List[DatasetConfig] = [DatasetConfig(**d) for d in kwargs.get('datasets', [])]
|
||||
self.datasets: List[ReferenceDatasetConfig] = [ReferenceDatasetConfig(**d) for d in kwargs.get('datasets', [])]
|
||||
|
||||
|
||||
class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
@@ -236,7 +220,6 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
loss.backward()
|
||||
flush()
|
||||
|
||||
|
||||
# apply gradients
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
Reference in New Issue
Block a user