Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-02-20 12:23:57 +00:00)
Worked on the reference slider script. It is working well currently; still going to tune it a bit before a writeup, though.
@@ -20,13 +20,18 @@ def flush():
     gc.collect()
 
 
+class DatasetConfig:
+    def __init__(self, **kwargs):
+        self.pair_folder: str = kwargs.get('pair_folder', None)
+        self.network_weight: float = kwargs.get('network_weight', 1.0)
+        self.target_class: str = kwargs.get('target_class', '')
+        self.size: int = kwargs.get('size', 512)
+
+
 class ReferenceSliderConfig:
     def __init__(self, **kwargs):
-        self.slider_pair_folder: str = kwargs.get('slider_pair_folder', None)
-        self.resolutions: List[int] = kwargs.get('resolutions', [512])
-        self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
-        self.target_class: int = kwargs.get('target_class', '')
         self.additional_losses: List[str] = kwargs.get('additional_losses', [])
+        self.datasets: List[DatasetConfig] = [DatasetConfig(**d) for d in kwargs.get('datasets', [])]
 
 
 class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
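
Review note: a minimal sketch of how the new datasets list is consumed, assuming only the two constructors shown above (the values and folder paths are illustrative placeholders, not from the commit):

    # Illustrative values only; the pair_folder paths are placeholders.
    slider_config = ReferenceSliderConfig(
        additional_losses=['mirror'],
        datasets=[
            {'pair_folder': '/path/to/pairs_a', 'network_weight': 2.0, 'size': 512},
            {'pair_folder': '/path/to/pairs_b', 'network_weight': 4.0},
        ],
    )
    print(slider_config.datasets[1].network_weight)  # 4.0
    print(slider_config.datasets[1].size)            # 512 (falls back to the default)
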
@@ -46,12 +51,13 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
         if self.data_loader is None:
             print(f"Loading datasets")
             datasets = []
-            for res in self.slider_config.resolutions:
-                print(f" - Dataset: {self.slider_config.slider_pair_folder}")
+            for dataset in self.slider_config.datasets:
+                print(f" - Dataset: {dataset.pair_folder}")
                 config = {
-                    'path': self.slider_config.slider_pair_folder,
-                    'size': res,
-                    'default_prompt': self.slider_config.target_class
+                    'path': dataset.pair_folder,
+                    'size': dataset.size,
+                    'default_prompt': dataset.target_class,
+                    'network_weight': dataset.network_weight,
                 }
                 image_dataset = PairedImageDataset(config)
                 datasets.append(image_dataset)
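
Review note: for reference, each entry in the YAML datasets list (see the config diff below) expands to one such config dict; with the first example dataset it would look like this (illustrative values):

    config = {
        'path': '/path/to/folder/side/by/side/images',  # dataset.pair_folder
        'size': 512,                                    # dataset.size
        'default_prompt': '',                           # dataset.target_class
        'network_weight': 2.0,                          # dataset.network_weight
    }
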
@@ -78,7 +84,7 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
         do_mirror_loss = 'mirror' in self.slider_config.additional_losses
 
         with torch.no_grad():
-            imgs, prompts = batch
+            imgs, prompts, base_network_weight = batch
             dtype = get_torch_dtype(self.train_config.dtype)
             imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype)
             # split batched images in half so left is negative and right is positive
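
Review note: the "split batched images in half" comment refers to the side-by-side pair layout. A minimal sketch of that split, assuming NCHW image tensors (so dim 3 is width); the shapes are illustrative:

    import torch

    imgs = torch.randn(2, 3, 512, 1024)              # two side-by-side pairs
    negative, positive = torch.chunk(imgs, 2, dim=3)  # left half / right half
    assert negative.shape == (2, 3, 512, 512)
    assert positive.shape == (2, 3, 512, 512)
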
@@ -129,7 +135,7 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
             noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0)
             noise = torch.cat([noise_positive, noise_negative], dim=0)
             timesteps = torch.cat([timesteps, timesteps], dim=0)
-            network_multiplier = [1.0, -1.0]
+            network_multiplier = [base_network_weight * 1.0, base_network_weight * -1.0]
 
             flush()
 
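Review note: the multiplier change simply scales the existing +1/-1 pair by the per-dataset weight, e.g. with the example weight of 2.0 from the config:

    base_network_weight = 2.0  # illustrative; comes from the batch unpacked above
    network_multiplier = [base_network_weight * 1.0, base_network_weight * -1.0]
    # -> [2.0, -2.0]: the LoRA is applied at +2x on the positive half and -2x on the negative half
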
@@ -180,7 +186,8 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
             noise_pred_pos, noise_pred_neg = torch.chunk(noise_pred, 2, dim=0)
             # mirror the negative
             noise_pred_neg = torch.flip(noise_pred_neg.clone(), dims=[3])
-            loss_mirror = torch.nn.functional.mse_loss(noise_pred_pos.float(), noise_pred_neg.float(), reduction="none")
+            loss_mirror = torch.nn.functional.mse_loss(noise_pred_pos.float(), noise_pred_neg.float(),
+                                                       reduction="none")
             loss_mirror = loss_mirror.mean([1, 2, 3])
             loss_mirror = loss_mirror.mean()
             loss_mirror_float = loss_mirror.item()
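
Review note: a self-contained sketch of the mirror loss as written above, on dummy tensors (the 4x64x64 latent shape is an assumption for illustration; dims=[3] flips along the width axis):

    import torch

    noise_pred = torch.randn(2, 4, 64, 64)  # stacked positive/negative predictions
    noise_pred_pos, noise_pred_neg = torch.chunk(noise_pred, 2, dim=0)
    # mirror the negative horizontally, then compare it to the positive
    noise_pred_neg = torch.flip(noise_pred_neg.clone(), dims=[3])
    loss_mirror = torch.nn.functional.mse_loss(noise_pred_pos.float(), noise_pred_neg.float(),
                                               reduction="none")
    loss_mirror = loss_mirror.mean([1, 2, 3]).mean()
    print(loss_mirror.item())
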
@@ -1,7 +1,7 @@
 ---
 job: extension
 config:
-  name: subject_turner_v1
+  name: example_name
   process:
     - type: 'image_reference_slider_trainer'
       training_folder: "/mnt/Train/out/LoRA"
@@ -10,10 +10,8 @@ config:
       log_dir: "/home/jaret/Dev/.tensorboard"
       network:
         type: "lora"
-        linear: 64
-        linear_alpha: 32
-        conv: 32
-        conv_alpha: 16
+        linear: 8
+        linear_alpha: 8
       train:
         noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
         steps: 5000
@@ -30,11 +28,9 @@ config:
         dtype: bf16
         xformers: true
         skip_first_sample: true
-        noise_offset: 0.0 # I don't recommend using unless you are trying to make a darker lora. Then do 0.1 MAX
+        noise_offset: 0.0
       model:
-        # name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/sdxl/sd_xl_base_0.9.safetensors"
-        name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/Ostris/Ostris_Real_v1.safetensors"
-        # name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/sd_v1-5_vae.ckpt"
+        name_or_path: "/path/to/model.safetensors"
         is_v2: false # for v2 models
         is_xl: false # for SDXL models
         is_v_pred: false # for v-prediction models (most v2 models)
@@ -81,18 +77,31 @@ config:
       verbose: false
 
       slider:
-        resolutions:
-          - 512
-        slider_pair_folder: "/mnt/Datasets/stable-diffusion/slider_reference/subject_turner"
-        target_class: "photo of a person"
-        # additional_losses:
-        #   - "mirror"
+        datasets:
+          - pair_folder: "/path/to/folder/side/by/side/images"
+            network_weight: 2.0
+            target_class: "" # only used as default if caption txt are not present
+            size: 512
+          - pair_folder: "/path/to/folder/side/by/side/images"
+            network_weight: 4.0
+            target_class: "" # only used as default if caption txt are not present
+            size: 512
 
 
+# you can put any information you want here, and it will be saved in the model
+# the below is an example. I recommend doing trigger words at a minimum
+# in the metadata. The software will include this plus some other information
 meta:
-  name: "[name]"
-  version: '1.0'
+  name: "[name]" # [name] gets replaced with the name above
+  description: A short description of your model
+  trigger_words:
+    - put
+    - trigger
+    - words
+    - here
+  version: '0.1'
   creator:
-    name: Ostris - Jaret Burkett
-    email: jaret@ostris.com
-    website: https://ostris.com
+    name: Your Name
+    email: your@email.com
+    website: https://yourwebsite.com
+  any: All meta data above is arbitrary, it can be whatever you want.

toolkit/basic.py (new file, 4 lines)
@@ -0,0 +1,4 @@
+
+
+def value_map(inputs, min_in, max_in, min_out, max_out):
+    return (inputs - min_in) * (max_out - min_out) / (max_in - min_in) + min_out
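
Review note: value_map is a plain linear rescale from one range to another; a couple of quick checks:

    from toolkit.basic import value_map

    print(value_map(5, 0, 10, 0, 100))          # 50.0
    print(value_map(0.5, 0.0, 1.0, -1.0, 1.0))  # 0.0
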
@@ -149,6 +149,7 @@ class PairedImageDataset(Dataset):
         self.size = self.get_config('size', 512)
         self.path = self.get_config('path', required=True)
         self.default_prompt = self.get_config('default_prompt', '')
+        self.network_weight = self.get_config('network_weight', 1.0)
         self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if
                          file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]
         print(f" - Found {len(self.file_list)} images")
@@ -200,5 +201,5 @@ class PairedImageDataset(Dataset):
         img = img.resize((width, height), Image.BICUBIC)
         img = self.transform(img)
 
-        return img, prompt
+        return img, prompt, self.network_weight
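
Review note: the extra return value relies on the loader's collate step to batch it. With a standard PyTorch DataLoader and its default collate_fn (an assumption; the trainer's actual loader is not shown in this diff), the per-item float is stacked into a 1-D tensor, one weight per batch element:

    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=4)  # 'dataset' is a PairedImageDataset
    imgs, prompts, network_weights = next(iter(loader))
    # network_weights has shape [4], e.g. tensor([2., 2., 2., 2.], dtype=torch.float64)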