Worked on reference slider script. It is working well currently. Still going to tune it a bit before a writeup though

This commit is contained in:
Jaret Burkett
2023-08-12 17:59:24 -06:00
parent fd95e7b60c
commit 196b693cf0
4 changed files with 54 additions and 33 deletions

View File

@@ -20,13 +20,18 @@ def flush():
gc.collect()
class DatasetConfig:
def __init__(self, **kwargs):
self.pair_folder: str = kwargs.get('pair_folder', None)
self.network_weight: float = kwargs.get('network_weight', 1.0)
self.target_class: str = kwargs.get('target_class', '')
self.size: int = kwargs.get('size', 512)
class ReferenceSliderConfig:
def __init__(self, **kwargs):
self.slider_pair_folder: str = kwargs.get('slider_pair_folder', None)
self.resolutions: List[int] = kwargs.get('resolutions', [512])
self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
self.target_class: int = kwargs.get('target_class', '')
self.additional_losses: List[str] = kwargs.get('additional_losses', [])
self.datasets: List[DatasetConfig] = [DatasetConfig(**d) for d in kwargs.get('datasets', [])]
class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
@@ -46,12 +51,13 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
if self.data_loader is None:
print(f"Loading datasets")
datasets = []
for res in self.slider_config.resolutions:
print(f" - Dataset: {self.slider_config.slider_pair_folder}")
for dataset in self.slider_config.datasets:
print(f" - Dataset: {dataset.pair_folder}")
config = {
'path': self.slider_config.slider_pair_folder,
'size': res,
'default_prompt': self.slider_config.target_class
'path': dataset.pair_folder,
'size': dataset.size,
'default_prompt': dataset.target_class,
'network_weight': dataset.network_weight,
}
image_dataset = PairedImageDataset(config)
datasets.append(image_dataset)
@@ -78,7 +84,7 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
do_mirror_loss = 'mirror' in self.slider_config.additional_losses
with torch.no_grad():
imgs, prompts = batch
imgs, prompts, base_network_weight = batch
dtype = get_torch_dtype(self.train_config.dtype)
imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype)
# split batched images in half so left is negative and right is positive
@@ -129,7 +135,7 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0)
noise = torch.cat([noise_positive, noise_negative], dim=0)
timesteps = torch.cat([timesteps, timesteps], dim=0)
network_multiplier = [1.0, -1.0]
network_multiplier = [base_network_weight * 1.0, base_network_weight * -1.0]
flush()
@@ -180,7 +186,8 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
noise_pred_pos, noise_pred_neg = torch.chunk(noise_pred, 2, dim=0)
# mirror the negative
noise_pred_neg = torch.flip(noise_pred_neg.clone(), dims=[3])
loss_mirror = torch.nn.functional.mse_loss(noise_pred_pos.float(), noise_pred_neg.float(), reduction="none")
loss_mirror = torch.nn.functional.mse_loss(noise_pred_pos.float(), noise_pred_neg.float(),
reduction="none")
loss_mirror = loss_mirror.mean([1, 2, 3])
loss_mirror = loss_mirror.mean()
loss_mirror_float = loss_mirror.item()

View File

@@ -1,7 +1,7 @@
---
job: extension
config:
name: subject_turner_v1
name: example_name
process:
- type: 'image_reference_slider_trainer'
training_folder: "/mnt/Train/out/LoRA"
@@ -10,10 +10,8 @@ config:
log_dir: "/home/jaret/Dev/.tensorboard"
network:
type: "lora"
linear: 64
linear_alpha: 32
conv: 32
conv_alpha: 16
linear: 8
linear_alpha: 8
train:
noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
steps: 5000
@@ -30,11 +28,9 @@ config:
dtype: bf16
xformers: true
skip_first_sample: true
noise_offset: 0.0 # I don't recommend using unless you are trying to make a darker lora. Then do 0.1 MAX
noise_offset: 0.0
model:
# name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/sdxl/sd_xl_base_0.9.safetensors"
name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/Ostris/Ostris_Real_v1.safetensors"
# name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/sd_v1-5_vae.ckpt"
name_or_path: "/path/to/model.safetensors"
is_v2: false # for v2 models
is_xl: false # for SDXL models
is_v_pred: false # for v-prediction models (most v2 models)
@@ -81,18 +77,31 @@ config:
verbose: false
slider:
resolutions:
- 512
slider_pair_folder: "/mnt/Datasets/stable-diffusion/slider_reference/subject_turner"
target_class: "photo of a person"
# additional_losses:
# - "mirror"
datasets:
- pair_folder: "/path/to/folder/side/by/side/images"
network_weight: 2.0
target_class: "" # only used as default if caption txt are not present
size: 512
- pair_folder: "/path/to/folder/side/by/side/images"
network_weight: 4.0
target_class: "" # only used as default if caption txt are not present
size: 512
# you can put any information you want here, and it will be saved in the model
# the below is an example. I recommend doing trigger words at a minimum
# in the metadata. The software will include this plus some other information
meta:
name: "[name]"
version: '1.0'
name: "[name]" # [name] gets replaced with the name above
description: A short description of your model
trigger_words:
- put
- trigger
- words
- here
version: '0.1'
creator:
name: Ostris - Jaret Burkett
email: jaret@ostris.com
website: https://ostris.com
name: Your Name
email: your@email.com
website: https://yourwebsite.com
any: All meta data above is arbitrary, it can be whatever you want.

4
toolkit/basic.py Normal file
View File

@@ -0,0 +1,4 @@
def value_map(inputs, min_in, max_in, min_out, max_out):
return (inputs - min_in) * (max_out - min_out) / (max_in - min_in) + min_out

View File

@@ -149,6 +149,7 @@ class PairedImageDataset(Dataset):
self.size = self.get_config('size', 512)
self.path = self.get_config('path', required=True)
self.default_prompt = self.get_config('default_prompt', '')
self.network_weight = self.get_config('network_weight', 1.0)
self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if
file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]
print(f" - Found {len(self.file_list)} images")
@@ -200,5 +201,5 @@ class PairedImageDataset(Dataset):
img = img.resize((width, height), Image.BICUBIC)
img = self.transform(img)
return img, prompt
return img, prompt, self.network_weight