SDXL should be working, but I broke something and it is not converging.

Jaret Burkett
2023-07-25 13:50:59 -06:00
parent 52f02d53f1
commit cb70c03273
11 changed files with 458 additions and 166 deletions

jobs/BaseJob.py

@@ -17,6 +17,7 @@ class BaseJob:
             raise ValueError('config is required')
         self.config = config['config']
+        self.raw_config = config
         self.job = config['job']
         self.name = self.get_conf('name', required=True)
         if 'meta' in config:

jobs/TrainJob.py

@@ -1,3 +1,4 @@
+import json
 import os
 from jobs import BaseJob
@@ -6,7 +7,7 @@ from collections import OrderedDict
 from typing import List
 from jobs.process import BaseExtractProcess, TrainFineTuneProcess
 from datetime import datetime
+import yaml
 from toolkit.paths import REPOS_ROOT
 import sys
@@ -16,6 +17,7 @@ sys.path.append(REPOS_ROOT)
 process_dict = {
     'vae': 'TrainVAEProcess',
     'slider': 'TrainSliderProcess',
+    'lora_hack': 'TrainLoRAHack',
 }
@@ -37,6 +39,13 @@ class TrainJob(BaseJob):
         # loads the processes from the config
         self.load_processes(process_dict)
 
+    def save_training_config(self):
+        timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
+        os.makedirs(self.training_folder, exist_ok=True)
+        save_dif = os.path.join(self.training_folder, f'run_config_{timestamp}.yaml')
+        with open(save_dif, 'w') as f:
+            yaml.dump(self.raw_config, f)
+
     def run(self):
         super().run()
         print("")

jobs/process/BaseSDTrainProcess.py

@@ -2,7 +2,6 @@ import time
 from collections import OrderedDict
 import os
-from leco.train_util import predict_noise
 from toolkit.kohya_model_util import load_vae
 from toolkit.lora_special import LoRASpecialNetwork
 from toolkit.optimizer import get_optimizer
@@ -12,7 +11,7 @@ import sys
 sys.path.append(REPOS_ROOT)
 sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
-from diffusers import StableDiffusionPipeline
+from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
 from jobs.process import BaseTrainProcess
 from toolkit.metadata import get_meta_for_safetensors
@@ -24,6 +23,7 @@ from tqdm import tqdm
 from leco import train_util, model_util
 from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig
+from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds
 
 def flush():
@@ -35,15 +35,6 @@ UNET_IN_CHANNELS = 4  # in_channels is fixed at 4 for Stable Diffusion; same for XL
 VAE_SCALE_FACTOR = 8  # 2 ** (len(vae.config.block_out_channels) - 1) = 8
 
-class StableDiffusion:
-    def __init__(self, vae, tokenizer, text_encoder, unet, noise_scheduler):
-        self.vae = vae
-        self.tokenizer = tokenizer
-        self.text_encoder = text_encoder
-        self.unet = unet
-        self.noise_scheduler = noise_scheduler
-
 class BaseSDTrainProcess(BaseTrainProcess):
     def __init__(self, process_id: int, job, config: OrderedDict):
         super().__init__(process_id, job, config)
@@ -80,26 +71,44 @@ class BaseSDTrainProcess(BaseTrainProcess):
         original_device_dict = {
             'vae': self.sd.vae.device,
             'unet': self.sd.unet.device,
-            'text_encoder': self.sd.text_encoder.device,
             # 'tokenizer': self.sd.tokenizer.device,
         }
 
+        # handle sdxl text encoder
+        if isinstance(self.sd.text_encoder, list):
+            for encoder, i in zip(self.sd.text_encoder, range(len(self.sd.text_encoder))):
+                original_device_dict[f'text_encoder_{i}'] = encoder.device
+                encoder.to(self.device_torch)
+        else:
+            original_device_dict['text_encoder'] = self.sd.text_encoder.device
+            self.sd.text_encoder.to(self.device_torch)
+
         self.sd.vae.to(self.device_torch)
         self.sd.unet.to(self.device_torch)
-        self.sd.text_encoder.to(self.device_torch)
+        # self.sd.text_encoder.to(self.device_torch)
         # self.sd.tokenizer.to(self.device_torch)
 
         # TODO add clip skip
-        pipeline = StableDiffusionPipeline(
-            vae=self.sd.vae,
-            unet=self.sd.unet,
-            text_encoder=self.sd.text_encoder,
-            tokenizer=self.sd.tokenizer,
-            scheduler=self.sd.noise_scheduler,
-            safety_checker=None,
-            feature_extractor=None,
-            requires_safety_checker=False,
-        )
+        if self.sd.is_xl:
+            pipeline = StableDiffusionXLPipeline(
+                vae=self.sd.vae,
+                unet=self.sd.unet,
+                text_encoder=self.sd.text_encoder[0],
+                text_encoder_2=self.sd.text_encoder[1],
+                tokenizer=self.sd.tokenizer[0],
+                tokenizer_2=self.sd.tokenizer[1],
+                scheduler=self.sd.noise_scheduler,
+            )
+        else:
+            pipeline = StableDiffusionPipeline(
+                vae=self.sd.vae,
+                unet=self.sd.unet,
+                text_encoder=self.sd.text_encoder,
+                tokenizer=self.sd.tokenizer,
+                scheduler=self.sd.noise_scheduler,
+                safety_checker=None,
+                feature_extractor=None,
+                requires_safety_checker=False,
+            )
 
         # disable progress bar
         pipeline.set_progress_bar_config(disable=True)
@@ -118,7 +127,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
             'multiplier': self.network.multiplier,
         })
 
-        for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}", leave=False):
+        for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}",
+                      leave=False):
             raw_prompt = self.sample_config.prompts[i]
             neg = self.sample_config.neg
@@ -180,7 +190,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.sd.vae.to(original_device_dict['vae'])
         self.sd.unet.to(original_device_dict['unet'])
-        self.sd.text_encoder.to(original_device_dict['text_encoder'])
+        if isinstance(self.sd.text_encoder, list):
+            for encoder, i in zip(self.sd.text_encoder, range(len(self.sd.text_encoder))):
+                encoder.to(original_device_dict[f'text_encoder_{i}'])
+        else:
+            self.sd.text_encoder.to(original_device_dict['text_encoder'])
         if self.network is not None:
             self.network.train()
             self.network.multiplier = start_multiplier
@@ -267,23 +281,90 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # return loss
             return 0.0
 
+    def get_time_ids_from_latents(self, latents):
+        bs, ch, h, w = list(latents.shape)
+        height = h * VAE_SCALE_FACTOR
+        width = w * VAE_SCALE_FACTOR
+        dtype = get_torch_dtype(self.train_config.dtype)
+
+        if self.sd.is_xl:
+            prompt_ids = train_util.get_add_time_ids(
+                height,
+                width,
+                dynamic_crops=False,  # look into this
+                dtype=dtype,
+            ).to(self.device_torch, dtype=dtype)
+            return train_util.concat_embeddings(
+                prompt_ids, prompt_ids, bs
+            )
+        else:
+            return None
+
+    def predict_noise(
+            self,
+            latents: torch.FloatTensor,
+            text_embeddings: PromptEmbeds,
+            timestep: int,
+            guidance_scale=7.5,
+            guidance_rescale=0.7,
+            add_time_ids=None,
+            **kwargs,
+    ):
+        if self.sd.is_xl:
+            if add_time_ids is None:
+                add_time_ids = self.get_time_ids_from_latents(latents)
+            # todo LECOs code looks like it is omitting noise_pred
+            noise_pred = train_util.predict_noise_xl(
+                self.sd.unet,
+                self.sd.noise_scheduler,
+                timestep,
+                latents,
+                text_embeddings.text_embeds,
+                text_embeddings.pooled_embeds,
+                add_time_ids,
+                guidance_scale=guidance_scale,
+                guidance_rescale=guidance_rescale
+            )
+
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.sd.noise_scheduler.step(noise_pred, timestep, latents).prev_sample
+        else:
+            noise_pred = train_util.predict_noise(
+                self.sd.unet,
+                self.sd.noise_scheduler,
+                timestep,
+                latents,
+                text_embeddings.text_embeds,
+                guidance_scale=guidance_scale
+            )
+
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.sd.noise_scheduler.step(noise_pred, timestep, latents).prev_sample
+
+        return latents
+
     # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
     def diffuse_some_steps(
             self,
             latents: torch.FloatTensor,
-            text_embeddings: torch.FloatTensor,
+            text_embeddings: PromptEmbeds,
             total_timesteps: int = 1000,
             start_timesteps=0,
+            guidance_scale=1,
+            add_time_ids=None,
             **kwargs,
     ):
         for timestep in tqdm(self.sd.noise_scheduler.timesteps[start_timesteps:total_timesteps], leave=False):
-            noise_pred = train_util.predict_noise(
-                self.sd.unet, self.sd.noise_scheduler, timestep, latents, text_embeddings, **kwargs
-            )
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.sd.noise_scheduler.step(noise_pred, timestep, latents).prev_sample
+            latents = self.predict_noise(
+                latents,
+                text_embeddings,
+                timestep,
+                guidance_scale=guidance_scale,
+                add_time_ids=add_time_ids,
+                **kwargs,
+            )
 
         # return latents_steps
         return latents
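A note on the two guidance knobs that predict_noise forwards to train_util.predict_noise_xl above: guidance_scale is classifier-free guidance and guidance_rescale is the std-rescaling trick used to counter overexposure at high guidance (as in diffusers' rescale_noise_cfg). This is a minimal sketch of the conventional math, for orientation only; the actual implementation lives in the leco repo's train_util.

    import torch

    def cfg_with_rescale(noise_uncond, noise_cond, guidance_scale=7.5, guidance_rescale=0.7):
        # classifier-free guidance: push the prediction away from the
        # unconditional branch toward the conditional one
        noise_cfg = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
        # rescale toward the conditional prediction's per-sample std
        dims = list(range(1, noise_cond.ndim))
        std_cond = noise_cond.std(dim=dims, keepdim=True)
        std_cfg = noise_cfg.std(dim=dims, keepdim=True)
        rescaled = noise_cfg * (std_cond / std_cfg)
        return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg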
@@ -296,20 +377,35 @@ class BaseSDTrainProcess(BaseTrainProcess):
         dtype = get_torch_dtype(self.train_config.dtype)
 
-        tokenizer, text_encoder, unet, noise_scheduler = model_util.load_models(
-            self.model_config.name_or_path,
-            scheduler_name=self.train_config.noise_scheduler,
-            v2=self.model_config.is_v2,
-            v_pred=self.model_config.is_v_pred,
-        )
+        if self.model_config.is_xl:
+            tokenizer, text_encoders, unet, noise_scheduler = model_util.load_models_xl(
+                self.model_config.name_or_path,
+                scheduler_name=self.train_config.noise_scheduler,
+                weight_dtype=dtype,
+            )
+            for text_encoder in text_encoders:
+                text_encoder.to(self.device_torch, dtype=dtype)
+                text_encoder.requires_grad_(False)
+                text_encoder.eval()
+            text_encoder = text_encoders
+        else:
+            tokenizer, text_encoder, unet, noise_scheduler = model_util.load_models(
+                self.model_config.name_or_path,
+                scheduler_name=self.train_config.noise_scheduler,
+                v2=self.model_config.is_v2,
+                v_pred=self.model_config.is_v_pred,
+            )
+            text_encoder.to(self.device_torch, dtype=dtype)
+            text_encoder.eval()
 
         # just for now or of we want to load a custom one
         # put on cpu for now, we only need it when sampling
         vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype)
         vae.eval()
 
-        self.sd = StableDiffusion(vae, tokenizer, text_encoder, unet, noise_scheduler)
+        self.sd = StableDiffusion(vae, tokenizer, text_encoder, unet, noise_scheduler, is_xl=self.model_config.is_xl)
 
-        text_encoder.to(self.device_torch, dtype=dtype)
-        text_encoder.eval()
         unet.to(self.device_torch, dtype=dtype)
 
         if self.train_config.xformers:
@@ -323,7 +419,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
             unet=unet,
             lora_dim=self.network_config.rank,
             multiplier=1.0,
-            alpha=self.network_config.alpha
+            alpha=self.network_config.alpha,
+            train_unet=self.train_config.train_unet,
+            train_text_encoder=self.train_config.train_text_encoder,
         )
         self.network.force_to(self.device_torch, dtype=dtype)
@@ -376,8 +474,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.hook_before_train_loop()
 
         # sample first
-        self.print("Generating baseline samples before training")
-        self.sample(0)
+        if self.train_config.skip_first_sample:
+            self.print("Skipping first sample due to config setting")
+        else:
+            self.print("Generating baseline samples before training")
+            self.sample(0)
 
         self.progress_bar = tqdm(
             total=self.train_config.steps,
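For context on get_time_ids_from_latents earlier in this file: SDXL's UNet takes six extra micro-conditioning values (original size, crop coordinates, target size) alongside the prompt embeddings. Below is a minimal sketch of what a get_add_time_ids-style helper conventionally returns, assuming the standard diffusers layout; the real helper used here is in the leco repo.

    import torch

    def sketch_add_time_ids(height, width, dynamic_crops=False, dtype=torch.float32):
        original_size = (height, width)
        crops_coords_top_left = (0, 0)  # dynamic_crops would randomize these
        target_size = (height, width)
        # flatten to the six values SDXL's UNet expects as added conditioning
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        return torch.tensor([add_time_ids], dtype=dtype)  # shape (1, 6)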

jobs/process/TrainLoRAHack.py

@@ -0,0 +1,76 @@
+# ref:
+# - https://github.com/p1atdev/LECO/blob/main/train_lora.py
+import time
+from collections import OrderedDict
+import os
+from toolkit.config_modules import SliderConfig
+from toolkit.paths import REPOS_ROOT
+import sys
+
+sys.path.append(REPOS_ROOT)
+sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
+
+from toolkit.train_tools import get_torch_dtype, apply_noise_offset
+import gc
+import torch
+from leco import train_util, model_util
+from leco.prompt_util import PromptEmbedsCache
+
+from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion
+
+
+def flush():
+    torch.cuda.empty_cache()
+    gc.collect()
+
+
+class LoRAHack:
+    def __init__(self, **kwargs):
+        self.type = kwargs.get('type', 'suppression')
+
+
+class TrainLoRAHack(BaseSDTrainProcess):
+
+    def __init__(self, process_id: int, job, config: OrderedDict):
+        super().__init__(process_id, job, config)
+        self.hack_config = LoRAHack(**self.get_conf('hack', {}))
+
+    def hook_before_train_loop(self):
+        # we don't need text encoder so move it to cpu
+        self.sd.text_encoder.to("cpu")
+        flush()
+
+        if self.hack_config.type == 'suppression':
+            # set all params to self.current_suppression
+            params = self.network.parameters()
+            for param in params:
+                # get random noise for each param
+                noise = torch.randn_like(param) - 0.5
+                # apply noise to param
+                param.data = noise * 0.001
+        # end hook_before_train_loop
+
+    def supress_loop(self):
+        dtype = get_torch_dtype(self.train_config.dtype)
+        loss_dict = OrderedDict(
+            {'sup': 0.0}
+        )
+        # increase noise
+        for param in self.network.parameters():
+            # get random noise for each param
+            noise = torch.randn_like(param) - 0.5
+            # apply noise to param
+            param.data = param.data + noise * 0.001
+        return loss_dict
+
+    def hook_train_loop(self):
+        if self.hack_config.type == 'suppression':
+            return self.supress_loop()
+        else:
+            raise NotImplementedError(f'unknown hack type: {self.hack_config.type}')
+        # end hook_train_loop

jobs/process/TrainSliderProcess.py

@@ -3,19 +3,22 @@
 import time
 from collections import OrderedDict
 import os
+from typing import Optional
 
 from toolkit.config_modules import SliderConfig
 from toolkit.paths import REPOS_ROOT
 import sys
 
+from toolkit.stable_diffusion_model import PromptEmbeds
+
 sys.path.append(REPOS_ROOT)
 sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
 
 from toolkit.train_tools import get_torch_dtype, apply_noise_offset
 import gc
+from toolkit import train_tools
 import torch
 from leco import train_util, model_util
 from leco.prompt_util import PromptEmbedsCache
 from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion
@@ -29,7 +32,6 @@ def flush():
     gc.collect()
 
-
 class EncodedPromptPair:
     def __init__(
             self,
@@ -54,6 +56,19 @@ class EncodedPromptPair:
         self.weight = weight
 
+
+class PromptEmbedsCache:  # so we can reuse embeddings
+    prompts: dict[str, PromptEmbeds] = {}
+
+    def __setitem__(self, __name: str, __value: PromptEmbeds) -> None:
+        self.prompts[__name] = __value
+
+    def __getitem__(self, __name: str) -> Optional[PromptEmbeds]:
+        if __name in self.prompts:
+            return self.prompts[__name]
+        else:
+            return None
+
 
 class EncodedAnchor:
     def __init__(
             self,
@@ -89,19 +104,17 @@ class TrainSliderProcess(BaseSDTrainProcess):
         with torch.no_grad():
             neutral = ""
             for target in self.slider_config.targets:
+                # build the cache
+                for prompt in [
+                    target.target_class,
+                    target.positive,
+                    target.negative,
+                    neutral  # empty neutral
+                ]:
+                    if cache[prompt] is None:
+                        cache[prompt] = self.sd.encode_prompt(prompt)
+
                 for resolution in self.slider_config.resolutions:
                     width, height = resolution
-                    # build the cache
-                    for prompt in [
-                        target.target_class,
-                        target.positive,
-                        target.negative,
-                        neutral  # empty neutral
-                    ]:
-                        if cache[prompt] == None:
-                            cache[prompt] = train_util.encode_prompts(
-                                self.sd.tokenizer, self.sd.text_encoder, [prompt]
-                            )
 
                     only_erase = len(target.positive.strip()) == 0
                     only_enhance = len(target.negative.strip()) == 0
@@ -184,9 +197,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
                     anchor.neg_prompt  # empty neutral
                 ]:
                     if cache[prompt] == None:
-                        cache[prompt] = train_util.encode_prompts(
-                            self.sd.tokenizer, self.sd.text_encoder, [prompt]
-                        )
+                        cache[prompt] = self.sd.encode_prompt(prompt)
 
                 anchor_pairs += [
                     EncodedAnchor(
@@ -198,7 +209,12 @@ class TrainSliderProcess(BaseSDTrainProcess):
         # move to cpu to save vram
         # We don't need text encoder anymore, but keep it on cpu for sampling
-        self.sd.text_encoder.to("cpu")
+        # if text encoder is list
+        if isinstance(self.sd.text_encoder, list):
+            for encoder in self.sd.text_encoder:
+                encoder.to("cpu")
+        else:
+            self.sd.text_encoder.to("cpu")
         self.prompt_cache = cache
         self.prompt_pairs = prompt_pairs
         self.anchor_pairs = anchor_pairs
@@ -220,6 +236,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
         negative = prompt_pair.negative
         positive = prompt_pair.positive
         weight = prompt_pair.weight
+        multiplier = prompt_pair.multiplier
 
         unet = self.sd.unet
         noise_scheduler = self.sd.noise_scheduler
@@ -227,8 +244,20 @@ class TrainSliderProcess(BaseSDTrainProcess):
         lr_scheduler = self.lr_scheduler
         loss_function = torch.nn.MSELoss()
 
+        def get_noise_pred(p, n):
+            return self.predict_noise(
+                latents=denoised_latents,
+                text_embeddings=train_tools.concat_prompt_embeddings(
+                    p,  # unconditional
+                    n,  # positive
+                    self.train_config.batch_size,
+                ),
+                timestep=current_timestep,
+                guidance_scale=1,
+            )
+
         # set network multiplier
-        self.network.multiplier = prompt_pair.multiplier
+        self.network.multiplier = multiplier
 
         with torch.no_grad():
             self.sd.noise_scheduler.set_timesteps(
@@ -254,9 +283,10 @@ class TrainSliderProcess(BaseSDTrainProcess):
             with self.network:
                 assert self.network.is_active
+                self.network.multiplier = multiplier
 
                 denoised_latents = self.diffuse_some_steps(
                     latents,  # pass simple noise latents
-                    train_util.concat_embeddings(
+                    train_tools.concat_prompt_embeddings(
                         positive,  # unconditional
                         target_class,  # target
                         self.train_config.batch_size,
@@ -272,43 +302,11 @@ class TrainSliderProcess(BaseSDTrainProcess):
                     int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
                 ]
 
-                # with network: 0 weight LoRA is enabled outside "with network:"
-                positive_latents = train_util.predict_noise(
-                    unet,
-                    noise_scheduler,
-                    current_timestep,
-                    denoised_latents,
-                    train_util.concat_embeddings(
-                        positive,  # unconditional
-                        negative,  # positive
-                        self.train_config.batch_size,
-                    ),
-                    guidance_scale=1,
-                ).to("cpu", dtype=torch.float32)
-                neutral_latents = train_util.predict_noise(
-                    unet,
-                    noise_scheduler,
-                    current_timestep,
-                    denoised_latents,
-                    train_util.concat_embeddings(
-                        positive,  # unconditional
-                        neutral,  # neutral
-                        self.train_config.batch_size,
-                    ),
-                    guidance_scale=1,
-                ).to("cpu", dtype=torch.float32)
-                unconditional_latents = train_util.predict_noise(
-                    unet,
-                    noise_scheduler,
-                    current_timestep,
-                    denoised_latents,
-                    train_util.concat_embeddings(
-                        positive,  # unconditional
-                        positive,  # unconditional
-                        self.train_config.batch_size,
-                    ),
-                    guidance_scale=1,
-                ).to("cpu", dtype=torch.float32)
+                positive_latents = get_noise_pred(positive, negative)
+                neutral_latents = get_noise_pred(positive, neutral)
+                unconditional_latents = get_noise_pred(positive, positive)
 
                 anchor_loss = None
                 if len(self.anchor_pairs) > 0:
@@ -317,51 +315,19 @@ class TrainSliderProcess(BaseSDTrainProcess):
                     torch.randint(0, len(self.anchor_pairs), (1,)).item()
                 ]
                 with torch.no_grad():
-                    anchor_target_noise = train_util.predict_noise(
-                        unet,
-                        noise_scheduler,
-                        current_timestep,
-                        denoised_latents,
-                        train_util.concat_embeddings(
-                            anchor.prompt,
-                            anchor.neg_prompt,
-                            self.train_config.batch_size,
-                        ),
-                        guidance_scale=1,
-                    ).to("cpu", dtype=torch.float32)
+                    anchor_target_noise = get_noise_pred(anchor.prompt, anchor.neg_prompt)
 
                     with self.network:
                         # anchor whatever weight prompt pair is using
                         pos_nem_mult = 1.0 if prompt_pair.multiplier > 0 else -1.0
                         self.network.multiplier = anchor.multiplier * pos_nem_mult
-                        anchor_pred_noise = train_util.predict_noise(
-                            unet,
-                            noise_scheduler,
-                            current_timestep,
-                            denoised_latents,
-                            train_util.concat_embeddings(
-                                anchor.prompt,
-                                anchor.neg_prompt,
-                                self.train_config.batch_size,
-                            ),
-                            guidance_scale=1,
-                        ).to("cpu", dtype=torch.float32)
+
+                        anchor_pred_noise = get_noise_pred(anchor.prompt, anchor.neg_prompt)
                         self.network.multiplier = prompt_pair.multiplier
 
             with self.network:
                 self.network.multiplier = prompt_pair.multiplier
-                target_latents = train_util.predict_noise(
-                    unet,
-                    noise_scheduler,
-                    current_timestep,
-                    denoised_latents,
-                    train_util.concat_embeddings(
-                        positive,  # unconditional
-                        target_class,  # target
-                        self.train_config.batch_size,
-                    ),
-                    guidance_scale=1,
-                ).to("cpu", dtype=torch.float32)
+                target_latents = get_noise_pred(positive, target_class)
 
             # if self.logging_config.verbose:
             #     self.print("target_latents:", target_latents[0, 0, :5, :5])

jobs/process/__init__.py

@@ -6,3 +6,4 @@ from .BaseTrainProcess import BaseTrainProcess
 from .TrainVAEProcess import TrainVAEProcess
 from .BaseMergeProcess import BaseMergeProcess
 from .TrainSliderProcess import TrainSliderProcess
+from .TrainLoRAHack import TrainLoRAHack

requirements.txt

@@ -9,4 +9,5 @@ accelerator
 pyyaml
 oyaml
 tensorboard
 kornia
+invisible-watermark

toolkit/config_modules.py

@@ -50,12 +50,14 @@ class TrainConfig:
         self.train_text_encoder = kwargs.get('train_text_encoder', True)
         self.noise_offset = kwargs.get('noise_offset', 0.0)
         self.optimizer_params = kwargs.get('optimizer_params', {})
+        self.skip_first_sample = kwargs.get('skip_first_sample', False)
 
 
 class ModelConfig:
     def __init__(self, **kwargs):
         self.name_or_path: str = kwargs.get('name_or_path', None)
         self.is_v2: bool = kwargs.get('is_v2', False)
+        self.is_xl: bool = kwargs.get('is_xl', False)
         self.is_v_pred: bool = kwargs.get('is_v_pred', False)
 
         if self.name_or_path is None:
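A minimal sketch of how the two new options are consumed, built directly from the kwargs above (any keys not passed fall back to their defaults):

    from toolkit.config_modules import ModelConfig, TrainConfig

    model_config = ModelConfig(
        name_or_path='stabilityai/stable-diffusion-xl-base-1.0',  # any SDXL checkpoint
        is_xl=True,  # routes loading through model_util.load_models_xl
    )
    train_config = TrainConfig(
        skip_first_sample=True,  # suppress the baseline sample at step 0
    )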

toolkit/lora_special.py

@@ -1,8 +1,10 @@
 import os
 import sys
-from typing import List
+from typing import List, Optional, Dict, Type, Union
 
 import torch
+from transformers import CLIPTextModel
 from .paths import SD_SCRIPTS_ROOT
 
 sys.path.append(SD_SCRIPTS_ROOT)
@@ -14,26 +16,40 @@ class LoRASpecialNetwork(LoRANetwork):
     _multiplier: float = 1.0
     is_active: bool = False
 
+    NUM_OF_BLOCKS = 12  # number of up/down blocks in the full model
+
+    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
+    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
+    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
+    LORA_PREFIX_UNET = "lora_unet"
+    LORA_PREFIX_TEXT_ENCODER = "lora_te"
+
+    # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
+    LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
+    LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
+
     def __init__(
             self,
-            text_encoder,
+            text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
             unet,
-            multiplier=1.0,
-            lora_dim=4,
-            alpha=1,
-            dropout=None,
-            rank_dropout=None,
-            module_dropout=None,
-            conv_lora_dim=None,
-            conv_alpha=None,
-            block_dims=None,
-            block_alphas=None,
-            conv_block_dims=None,
-            conv_block_alphas=None,
-            modules_dim=None,
-            modules_alpha=None,
-            module_class=LoRAModule,
-            varbose=False,
+            multiplier: float = 1.0,
+            lora_dim: int = 4,
+            alpha: float = 1,
+            dropout: Optional[float] = None,
+            rank_dropout: Optional[float] = None,
+            module_dropout: Optional[float] = None,
+            conv_lora_dim: Optional[int] = None,
+            conv_alpha: Optional[float] = None,
+            block_dims: Optional[List[int]] = None,
+            block_alphas: Optional[List[float]] = None,
+            conv_block_dims: Optional[List[int]] = None,
+            conv_block_alphas: Optional[List[float]] = None,
+            modules_dim: Optional[Dict[str, int]] = None,
+            modules_alpha: Optional[Dict[str, int]] = None,
+            module_class: Type[object] = LoRAModule,
+            varbose: Optional[bool] = False,
+            train_text_encoder: Optional[bool] = True,
+            train_unet: Optional[bool] = True,
     ) -> None:
         """
         LoRA network: there are a lot of arguments, but the patterns are as follows
@@ -75,8 +91,21 @@ class LoRASpecialNetwork(LoRANetwork):
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
# create module instances # create module instances
def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]: def create_modules(
prefix = LoRANetwork.LORA_PREFIX_UNET if is_unet else LoRANetwork.LORA_PREFIX_TEXT_ENCODER is_unet: bool,
text_encoder_idx: Optional[int], # None, 1, 2
root_module: torch.nn.Module,
target_replace_modules: List[torch.nn.Module],
) -> List[LoRAModule]:
prefix = (
self.LORA_PREFIX_UNET
if is_unet
else (
self.LORA_PREFIX_TEXT_ENCODER
if text_encoder_idx is None
else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
)
)
loras = [] loras = []
skipped = [] skipped = []
for name, module in root_module.named_modules(): for name, module in root_module.named_modules():
@@ -92,11 +121,14 @@ class LoRASpecialNetwork(LoRANetwork):
                         dim = None
                         alpha = None
 
                         if modules_dim is not None:
+                            # modules are explicitly specified
                             if lora_name in modules_dim:
                                 dim = modules_dim[lora_name]
                                 alpha = modules_alpha[lora_name]
                         elif is_unet and block_dims is not None:
+                            # block_dims specified for the U-Net
                             block_idx = get_block_index(lora_name)
                             if is_linear or is_conv2d_1x1:
                                 dim = block_dims[block_idx]
@@ -105,6 +137,7 @@ class LoRASpecialNetwork(LoRANetwork):
                                 dim = conv_block_dims[block_idx]
                                 alpha = conv_block_alphas[block_idx]
                         else:
+                            # default: target everything
                             if is_linear or is_conv2d_1x1:
                                 dim = self.lora_dim
                                 alpha = self.alpha
@@ -113,6 +146,7 @@ class LoRASpecialNetwork(LoRANetwork):
                                 alpha = self.conv_alpha
 
                         if dim is None or dim == 0:
+                            # report skipped modules
                             if is_linear or is_conv2d_1x1 or (
                                     self.conv_lora_dim is not None or conv_block_dims is not None):
                                 skipped.append(lora_name)
@@ -131,8 +165,25 @@ class LoRASpecialNetwork(LoRANetwork):
                     loras.append(lora)
             return loras, skipped
 
-        self.text_encoder_loras, skipped_te = create_modules(False, text_encoder,
-                                                             LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+        text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
+
+        # create LoRA for text encoder
+        # building every module each time is wasteful; worth revisiting
+        self.text_encoder_loras = []
+        skipped_te = []
+        if train_text_encoder:
+            for i, text_encoder in enumerate(text_encoders):
+                if len(text_encoders) > 1:
+                    index = i + 1
+                    print(f"create LoRA for Text Encoder {index}:")
+                else:
+                    index = None
+                    print(f"create LoRA for Text Encoder:")
+
+                text_encoder_loras, skipped = create_modules(False, index, text_encoder,
+                                                             LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+                self.text_encoder_loras.extend(text_encoder_loras)
+                skipped_te += skipped
         print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
 
         # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
@@ -140,7 +191,11 @@ class LoRASpecialNetwork(LoRANetwork):
         if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
             target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
 
-        self.unet_loras, skipped_un = create_modules(True, unet, target_modules)
+        if train_unet:
+            self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
+        else:
+            self.unet_loras = []
+            skipped_un = []
         print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
 
         skipped = skipped_te + skipped_un
@@ -159,8 +214,7 @@ class LoRASpecialNetwork(LoRANetwork):
         # assertion
         names = set()
         for lora in self.text_encoder_loras + self.unet_loras:
-            # doesnt work on new diffusers. TODO make sure we are not missing something
-            # assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
+            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
             names.add(lora.lora_name)
 
     def save_weights(self, file, dtype, metadata):

toolkit/stable_diffusion_model.py

@@ -0,0 +1,63 @@
+from typing import Union
+import sys
+import os
+
+from toolkit.paths import REPOS_ROOT
+
+sys.path.append(REPOS_ROOT)
+sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
+
+from leco import train_util
+import torch
+
+
+class PromptEmbeds:
+    text_embeds: torch.FloatTensor
+    pooled_embeds: Union[torch.FloatTensor, None]
+
+    def __init__(self, args) -> None:
+        if isinstance(args, list) or isinstance(args, tuple):
+            # xl
+            self.text_embeds = args[0]
+            self.pooled_embeds = args[1]
+        else:
+            # sdv1.x, sdv2.x
+            self.text_embeds = args
+            self.pooled_embeds = None
+
+
+class StableDiffusion:
+    def __init__(
+            self,
+            vae,
+            tokenizer,
+            text_encoder,
+            unet,
+            noise_scheduler,
+            is_xl=False
+    ):
+        # text encoder is a list of 2 for xl
+        self.vae = vae
+        self.tokenizer = tokenizer
+        self.text_encoder = text_encoder
+        self.unet = unet
+        self.noise_scheduler = noise_scheduler
+        self.is_xl = is_xl
+
+    def encode_prompt(self, prompt, num_images_per_prompt=1) -> PromptEmbeds:
+        # if it is not a list, make it one
+        if not isinstance(prompt, list):
+            prompt = [prompt]
+        if self.is_xl:
+            return PromptEmbeds(
+                train_util.encode_prompts_xl(
+                    self.tokenizer,
+                    self.text_encoder,
+                    prompt,
+                    num_images_per_prompt=num_images_per_prompt,
+                )
+            )
+        else:
+            return PromptEmbeds(
+                train_util.encode_prompts(
+                    self.tokenizer, self.text_encoder, prompt
+                )
+            )
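A usage sketch for the wrapper above, assuming `sd` is an already-constructed StableDiffusion instance: encode_prompt hides the SD-vs-SDXL difference behind PromptEmbeds.

    embeds = sd.encode_prompt("a photo of a dog")
    print(embeds.text_embeds.shape)        # per-token embeddings in both cases
    if sd.is_xl:
        print(embeds.pooled_embeds.shape)  # pooled embeddings, XL only; None for SD 1.x/2.x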

toolkit/train_tools.py

@@ -21,6 +21,8 @@ from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
 import torch
 import re
 
+from toolkit.stable_diffusion_model import PromptEmbeds
+
 SCHEDULER_LINEAR_START = 0.00085
 SCHEDULER_LINEAR_END = 0.0120
 SCHEDULER_TIMESTEPS = 1000
@@ -377,3 +379,19 @@ def apply_noise_offset(noise, noise_offset):
         return noise
     noise = noise + noise_offset * torch.randn((noise.shape[0], noise.shape[1], 1, 1), device=noise.device)
     return noise
+
+
+def concat_prompt_embeddings(
+        unconditional: PromptEmbeds,
+        conditional: PromptEmbeds,
+        n_imgs: int,
+):
+    text_embeds = torch.cat(
+        [unconditional.text_embeds, conditional.text_embeds]
+    ).repeat_interleave(n_imgs, dim=0)
+    pooled_embeds = None
+    if unconditional.pooled_embeds is not None and conditional.pooled_embeds is not None:
+        pooled_embeds = torch.cat(
+            [unconditional.pooled_embeds, conditional.pooled_embeds]
+        ).repeat_interleave(n_imgs, dim=0)
+    return PromptEmbeds([text_embeds, pooled_embeds])
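A quick sketch of what concat_prompt_embeddings produces for classifier-free-guidance batching (`sd` is an assumed StableDiffusion instance, as in the earlier example):

    uncond = sd.encode_prompt("")                # unconditional / negative
    cond = sd.encode_prompt("a photo of a dog")  # conditional
    batch = concat_prompt_embeddings(uncond, cond, n_imgs=2)
    # batch.text_embeds rows are [uncond, uncond, cond, cond]; pooled_embeds
    # gets the same layout when both inputs carry pooled embeddings (SDXL)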