mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-27 00:49:47 +00:00
Added built-in plugins and made one for image reference sliders. WIP
This commit is contained in:
@@ -25,6 +25,7 @@ class ReferenceSliderConfig:
|
||||
self.resolutions: List[int] = kwargs.get('resolutions', [512])
|
||||
self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
|
||||
self.target_class: int = kwargs.get('target_class', '')
|
||||
self.additional_losses: List[str] = kwargs.get('additional_losses', [])
|
||||
|
||||
|
||||
class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
@@ -73,6 +74,8 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
pass
|
||||
|
||||
def hook_train_loop(self, batch):
|
||||
do_mirror_loss = 'mirror' in self.slider_config.additional_losses
|
||||
|
||||
with torch.no_grad():
|
||||
imgs, prompts = batch
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
@@ -135,62 +138,89 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# get noise
|
||||
noise = self.sd.get_latent_noise(
|
||||
noise_positive = self.sd.get_latent_noise(
|
||||
pixel_height=height,
|
||||
pixel_width=width,
|
||||
batch_size=batch_size,
|
||||
noise_offset=self.train_config.noise_offset,
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
|
||||
if do_mirror_loss:
|
||||
# mirror the noise
|
||||
noise_negative = torch.flip(noise_positive.clone(), dims=[3])
|
||||
else:
|
||||
noise_negative = noise_positive.clone()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_positive_latents = noise_scheduler.add_noise(positive_latents, noise, timesteps)
|
||||
noisy_negative_latents = noise_scheduler.add_noise(negative_latents, noise, timesteps)
|
||||
noisy_positive_latents = noise_scheduler.add_noise(positive_latents, noise_positive, timesteps)
|
||||
noisy_negative_latents = noise_scheduler.add_noise(negative_latents, noise_negative, timesteps)
|
||||
|
||||
noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0)
|
||||
noise = torch.cat([noise_positive, noise_negative], dim=0)
|
||||
timesteps = torch.cat([timesteps, timesteps], dim=0)
|
||||
conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
|
||||
unconditional_embeds = concat_prompt_embeds([unconditional_embeds, unconditional_embeds])
|
||||
network_multiplier = [1.0, -1.0]
|
||||
|
||||
flush()
|
||||
|
||||
loss_float = None
|
||||
loss_slide_float = None
|
||||
loss_mirror_float = None
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
with self.network:
|
||||
assert self.network.is_active
|
||||
loss_list = []
|
||||
for noisy_latents, network_multiplier in zip(
|
||||
[noisy_positive_latents, noisy_negative_latents],
|
||||
[1.0, -1.0],
|
||||
):
|
||||
# do positive first
|
||||
self.network.multiplier = network_multiplier
|
||||
|
||||
noise_pred = get_noise_pred(
|
||||
unconditional_embeds,
|
||||
conditional_embeds,
|
||||
1,
|
||||
timesteps,
|
||||
noisy_latents
|
||||
)
|
||||
# do positive first
|
||||
self.network.multiplier = network_multiplier
|
||||
|
||||
if self.sd.is_v2: # check is vpred, don't want to track it down right now
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
noise_pred = get_noise_pred(
|
||||
unconditional_embeds,
|
||||
conditional_embeds,
|
||||
1,
|
||||
timesteps,
|
||||
noisy_latents
|
||||
)
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
if self.sd.is_v2: # check is vpred, don't want to track it down right now
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
|
||||
# todo add snr gamma here
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss = loss.mean()
|
||||
# back propagate loss to free ram
|
||||
loss.backward()
|
||||
loss_list.append(loss.item())
|
||||
# todo add snr gamma here
|
||||
|
||||
flush()
|
||||
loss = loss.mean()
|
||||
loss_slide_float = loss.item()
|
||||
|
||||
|
||||
if do_mirror_loss:
|
||||
noise_pred_pos, noise_pred_neg = torch.chunk(noise_pred, 2, dim=0)
|
||||
# mirror the negative
|
||||
noise_pred_neg = torch.flip(noise_pred_neg.clone(), dims=[3])
|
||||
loss_mirror = torch.nn.functional.mse_loss(noise_pred_pos.float(), noise_pred_neg.float(), reduction="none")
|
||||
loss_mirror = loss_mirror.mean([1, 2, 3])
|
||||
loss_mirror = loss_mirror.mean()
|
||||
loss_mirror_float = loss_mirror.item()
|
||||
loss += loss_mirror
|
||||
|
||||
loss_float = loss.item()
|
||||
|
||||
# back propagate loss to free ram
|
||||
loss.backward()
|
||||
|
||||
flush()
|
||||
|
||||
# apply gradients
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
loss_float = sum(loss_list) / len(loss_list)
|
||||
|
||||
# reset network
|
||||
self.network.multiplier = 1.0
|
||||
@@ -198,5 +228,9 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
loss_dict = OrderedDict(
|
||||
{'loss': loss_float},
|
||||
)
|
||||
|
||||
if do_mirror_loss:
|
||||
loss_dict['l/s'] = loss_slide_float
|
||||
loss_dict['l/m'] = loss_mirror_float
|
||||
return loss_dict
|
||||
# end hook_train_loop
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
---
|
||||
job: extension
|
||||
config:
|
||||
name: subject_turner_v1
|
||||
process:
|
||||
- type: 'image_reference_slider_trainer'
|
||||
training_folder: "/mnt/Train/out/LoRA"
|
||||
device: cuda:0
|
||||
# for tensorboard logging
|
||||
log_dir: "/home/jaret/Dev/.tensorboard"
|
||||
network:
|
||||
type: "lierla" # lierla is traditional LoRA that works everywhere, only linear layers
|
||||
rank: 16
|
||||
alpha: 8
|
||||
train:
|
||||
noise_scheduler: "ddpm" # or "lms", "euler_a"
|
||||
steps: 1000
|
||||
lr: 5e-5
|
||||
train_unet: true
|
||||
gradient_checkpointing: true
|
||||
train_text_encoder: false
|
||||
optimizer: "lion8bit"
|
||||
lr_scheduler: "constant"
|
||||
max_denoising_steps: 1000
|
||||
batch_size: 1
|
||||
dtype: bf16
|
||||
xformers: true
|
||||
skip_first_sample: true
|
||||
noise_offset: 0.0 # I don't recommend using unless you are trying to make a darker lora. Then do 0.1 MAX
|
||||
model:
|
||||
# name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/sdxl/sd_xl_base_0.9.safetensors"
|
||||
name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/Ostris/Ostris_Real_v1.safetensors"
|
||||
# name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/sd_v1-5_vae.ckpt"
|
||||
is_v2: false # for v2 models
|
||||
is_xl: false # for SDXL models
|
||||
is_v_pred: false # for v-prediction models (most v2 models)
|
||||
save:
|
||||
dtype: float16 # precision to save
|
||||
save_every: 100 # save every this many steps
|
||||
max_step_saves_to_keep: 2 # only affects step counts
|
||||
sample:
|
||||
sampler: "ddpm" # must match train.noise_scheduler
|
||||
sample_every: 20 # sample every this many steps
|
||||
width: 512
|
||||
height: 512
|
||||
prompts:
|
||||
- "photo of a woman with red hair taking a selfie --m -3"
|
||||
- "photo of a woman with red hair taking a selfie --m -1"
|
||||
- "photo of a woman with red hair taking a selfie --m 1"
|
||||
- "photo of a woman with red hair taking a selfie --m 3"
|
||||
- "close up photo of a man smiling at the camera, in a tank top --m -3"
|
||||
- "close up photo of a man smiling at the camera, in a tank top --m -1"
|
||||
- "close up photo of a man smiling at the camera, in a tank top --m 1"
|
||||
- "close up photo of a man smiling at the camera, in a tank top --m 3"
|
||||
- "photo of a blonde woman smiling, barista --m -3"
|
||||
- "photo of a blonde woman smiling, barista --m -1"
|
||||
- "photo of a blonde woman smiling, barista --m 1"
|
||||
- "photo of a blonde woman smiling, barista --m 3"
|
||||
- "photo of a Christina Hendricks --m -3"
|
||||
- "photo of a Christina Hendricks --m -1"
|
||||
- "photo of a Christina Hendricks --m 1"
|
||||
- "photo of a Christina Hendricks --m 3"
|
||||
- "photo of a Christina Ricci --m -3"
|
||||
- "photo of a Christina Ricci --m -1"
|
||||
- "photo of a Christina Ricci --m 1"
|
||||
- "photo of a Christina Ricci --m 3"
|
||||
neg: "cartoon, fake, drawing, illustration, cgi, animated, anime"
|
||||
seed: 42
|
||||
walk_seed: false
|
||||
guidance_scale: 7
|
||||
sample_steps: 20
|
||||
network_multiplier: 1.0
|
||||
|
||||
logging:
|
||||
log_every: 10 # log every this many steps
|
||||
use_wandb: false # not supported yet
|
||||
verbose: false
|
||||
|
||||
slider:
|
||||
resolutions:
|
||||
- 512
|
||||
slider_pair_folder: "/mnt/Datasets/stable-diffusion/slider_reference/subject_turner"
|
||||
target_class: "photo of a person"
|
||||
|
||||
|
||||
meta:
|
||||
name: "[name]"
|
||||
version: '1.0'
|
||||
creator:
|
||||
name: Ostris - Jaret Burkett
|
||||
email: jaret@ostris.com
|
||||
website: https://ostris.com
|
||||
|
||||
Reference in New Issue
Block a user