Added inbuild plugins and made one for image referenced. WIP

This commit is contained in:
Jaret Burkett
2023-08-10 16:20:38 -06:00
parent 1a7e346b41
commit 67dfd9ced0
2 changed files with 156 additions and 30 deletions

View File

@@ -25,6 +25,7 @@ class ReferenceSliderConfig:
self.resolutions: List[int] = kwargs.get('resolutions', [512])
self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
self.target_class: int = kwargs.get('target_class', '')
self.additional_losses: List[str] = kwargs.get('additional_losses', [])
class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
@@ -73,6 +74,8 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
pass
def hook_train_loop(self, batch):
do_mirror_loss = 'mirror' in self.slider_config.additional_losses
with torch.no_grad():
imgs, prompts = batch
dtype = get_torch_dtype(self.train_config.dtype)
@@ -135,62 +138,89 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
timesteps = timesteps.long()
# get noise
noise = self.sd.get_latent_noise(
noise_positive = self.sd.get_latent_noise(
pixel_height=height,
pixel_width=width,
batch_size=batch_size,
noise_offset=self.train_config.noise_offset,
).to(self.device_torch, dtype=dtype)
if do_mirror_loss:
# mirror the noise
noise_negative = torch.flip(noise_positive.clone(), dims=[3])
else:
noise_negative = noise_positive.clone()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_positive_latents = noise_scheduler.add_noise(positive_latents, noise, timesteps)
noisy_negative_latents = noise_scheduler.add_noise(negative_latents, noise, timesteps)
noisy_positive_latents = noise_scheduler.add_noise(positive_latents, noise_positive, timesteps)
noisy_negative_latents = noise_scheduler.add_noise(negative_latents, noise_negative, timesteps)
noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0)
noise = torch.cat([noise_positive, noise_negative], dim=0)
timesteps = torch.cat([timesteps, timesteps], dim=0)
conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
unconditional_embeds = concat_prompt_embeds([unconditional_embeds, unconditional_embeds])
network_multiplier = [1.0, -1.0]
flush()
loss_float = None
loss_slide_float = None
loss_mirror_float = None
self.optimizer.zero_grad()
with self.network:
assert self.network.is_active
loss_list = []
for noisy_latents, network_multiplier in zip(
[noisy_positive_latents, noisy_negative_latents],
[1.0, -1.0],
):
# do positive first
self.network.multiplier = network_multiplier
noise_pred = get_noise_pred(
unconditional_embeds,
conditional_embeds,
1,
timesteps,
noisy_latents
)
# do positive first
self.network.multiplier = network_multiplier
if self.sd.is_v2: # check is vpred, don't want to track it down right now
# v-parameterization training
target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
else:
target = noise
noise_pred = get_noise_pred(
unconditional_embeds,
conditional_embeds,
1,
timesteps,
noisy_latents
)
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
if self.sd.is_v2: # check is vpred, don't want to track it down right now
# v-parameterization training
target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
else:
target = noise
# todo add snr gamma here
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss = loss.mean()
# back propagate loss to free ram
loss.backward()
loss_list.append(loss.item())
# todo add snr gamma here
flush()
loss = loss.mean()
loss_slide_float = loss.item()
if do_mirror_loss:
noise_pred_pos, noise_pred_neg = torch.chunk(noise_pred, 2, dim=0)
# mirror the negative
noise_pred_neg = torch.flip(noise_pred_neg.clone(), dims=[3])
loss_mirror = torch.nn.functional.mse_loss(noise_pred_pos.float(), noise_pred_neg.float(), reduction="none")
loss_mirror = loss_mirror.mean([1, 2, 3])
loss_mirror = loss_mirror.mean()
loss_mirror_float = loss_mirror.item()
loss += loss_mirror
loss_float = loss.item()
# back propagate loss to free ram
loss.backward()
flush()
# apply gradients
optimizer.step()
lr_scheduler.step()
loss_float = sum(loss_list) / len(loss_list)
# reset network
self.network.multiplier = 1.0
@@ -198,5 +228,9 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
loss_dict = OrderedDict(
{'loss': loss_float},
)
if do_mirror_loss:
loss_dict['l/s'] = loss_slide_float
loss_dict['l/m'] = loss_mirror_float
return loss_dict
# end hook_train_loop

View File

@@ -0,0 +1,92 @@
---
job: extension
config:
name: subject_turner_v1
process:
- type: 'image_reference_slider_trainer'
training_folder: "/mnt/Train/out/LoRA"
device: cuda:0
# for tensorboard logging
log_dir: "/home/jaret/Dev/.tensorboard"
network:
type: "lierla" # lierla is traditional LoRA that works everywhere, only linear layers
rank: 16
alpha: 8
train:
noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
steps: 1000
lr: 5e-5
train_unet: true
gradient_checkpointing: true
train_text_encoder: false
optimizer: "lion8bit"
lr_scheduler: "constant"
max_denoising_steps: 1000
batch_size: 1
dtype: bf16
xformers: true
skip_first_sample: true
noise_offset: 0.0 # I don't recommend using unless you are trying to make a darker lora. Then do 0.1 MAX
model:
# name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/sdxl/sd_xl_base_0.9.safetensors"
name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/Ostris/Ostris_Real_v1.safetensors"
# name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/sd_v1-5_vae.ckpt"
is_v2: false # for v2 models
is_xl: false # for SDXL models
is_v_pred: false # for v-prediction models (most v2 models)
save:
dtype: float16 # precision to save
save_every: 100 # save every this many steps
max_step_saves_to_keep: 2 # only affects step counts
sample:
sampler: "ddpm" # must match train.noise_scheduler
sample_every: 20 # sample every this many steps
width: 512
height: 512
prompts:
- "photo of a woman with red hair taking a selfie --m -3"
- "photo of a woman with red hair taking a selfie --m -1"
- "photo of a woman with red hair taking a selfie --m 1"
- "photo of a woman with red hair taking a selfie --m 3"
- "close up photo of a man smiling at the camera, in a tank top --m -3"
- "close up photo of a man smiling at the camera, in a tank top--m -1"
- "close up photo of a man smiling at the camera, in a tank top --m 1"
- "close up photo of a man smiling at the camera, in a tank top --m 3"
- "photo of a blonde woman smiling, barista --m -3"
- "photo of a blonde woman smiling, barista --m -1"
- "photo of a blonde woman smiling, barista --m 1"
- "photo of a blonde woman smiling, barista --m 3"
- "photo of a Christina Hendricks --m -1"
- "photo of a Christina Hendricks --m -1"
- "photo of a Christina Hendricks --m 1"
- "photo of a Christina Hendricks --m 3"
- "photo of a Christina Ricci --m -3"
- "photo of a Christina Ricci --m -1"
- "photo of a Christina Ricci --m 1"
- "photo of a Christina Ricci --m 3"
neg: "cartoon, fake, drawing, illustration, cgi, animated, anime"
seed: 42
walk_seed: false
guidance_scale: 7
sample_steps: 20
network_multiplier: 1.0
logging:
log_every: 10 # log every this many steps
use_wandb: false # not supported yet
verbose: false
slider:
resolutions:
- 512
slider_pair_folder: "/mnt/Datasets/stable-diffusion/slider_reference/subject_turner"
target_class: "photo of a person"
meta:
name: "[name]"
version: '1.0'
creator:
name: Ostris - Jaret Burkett
email: jaret@ostris.com
website: https://ostris.com