Base for loopback lora training setup, still working on proper sliders

This commit is contained in:
Jaret Burkett
2023-07-21 18:26:02 -06:00
parent aa13251877
commit ddcd9069e1
8 changed files with 1097 additions and 13 deletions

View File

@@ -1,10 +1,14 @@
import os
from collections import OrderedDict from collections import OrderedDict
from typing import ForwardRef
from jobs.process.BaseProcess import BaseProcess from jobs.process.BaseProcess import BaseProcess
class BaseTrainProcess(BaseProcess): class BaseTrainProcess(BaseProcess):
process_id: int process_id: int
config: OrderedDict config: OrderedDict
progress_bar: ForwardRef('tqdm') = None
def __init__( def __init__(
self, self,
@@ -13,8 +17,23 @@ class BaseTrainProcess(BaseProcess):
config: OrderedDict config: OrderedDict
): ):
super().__init__(process_id, job, config) super().__init__(process_id, job, config)
self.progress_bar = None
self.writer = self.job.writer
self.training_folder = self.get_conf('training_folder', self.job.training_folder)
self.save_root = os.path.join(self.training_folder, self.job.name)
self.step = 0
self.first_step = 0
def run(self): def run(self):
super().run()
# implement in child class # implement in child class
# be sure to call super().run() first # be sure to call super().run() first
pass pass
# def print(self, message, **kwargs):
def print(self, *args):
if self.progress_bar is not None:
self.progress_bar.write(' '.join(map(str, args)))
self.progress_bar.update()
else:
print(*args)

View File

@@ -0,0 +1,609 @@
# ref:
# - https://github.com/p1atdev/LECO/blob/main/train_lora.py
import time
from collections import OrderedDict
import os
from toolkit.kohya_model_util import load_vae
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.paths import REPOS_ROOT
import sys
sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from diffusers import StableDiffusionPipeline
from jobs.process import BaseTrainProcess
from toolkit.metadata import get_meta_for_safetensors
from toolkit.train_tools import get_torch_dtype
import gc
import torch
from tqdm import tqdm
from toolkit.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV, TRAINING_METHODS
from leco import train_util, model_util
from leco.prompt_util import PromptEmbedsCache, PromptEmbedsPair, ACTION_TYPES
from leco import debug_util
def flush():
torch.cuda.empty_cache()
gc.collect()
class StableDiffusion:
def __init__(self, vae, tokenizer, text_encoder, unet, noise_scheduler):
self.vae = vae
self.tokenizer = tokenizer
self.text_encoder = text_encoder
self.unet = unet
self.noise_scheduler = noise_scheduler
class SaveConfig:
def __init__(self, **kwargs):
self.save_every: int = kwargs.get('save_every', 1000)
self.dtype: str = kwargs.get('save_dtype', 'float16')
class LogingConfig:
def __init__(self, **kwargs):
self.log_every: int = kwargs.get('log_every', 100)
self.verbose: bool = kwargs.get('verbose', False)
self.use_wandb: bool = kwargs.get('use_wandb', False)
class SampleConfig:
def __init__(self, **kwargs):
self.sample_every: int = kwargs.get('sample_every', 100)
self.width: int = kwargs.get('width', 512)
self.height: int = kwargs.get('height', 512)
self.prompts: list[str] = kwargs.get('prompts', [])
self.neg = kwargs.get('neg', False)
self.seed = kwargs.get('seed', 0)
self.walk_seed = kwargs.get('walk_seed', False)
self.guidance_scale = kwargs.get('guidance_scale', 7)
self.sample_steps = kwargs.get('sample_steps', 20)
class NetworkConfig:
def __init__(self, **kwargs):
self.type: str = kwargs.get('type', 'lierla')
self.rank: int = kwargs.get('rank', 4)
self.alpha: float = kwargs.get('alpha', 1.0)
class TrainConfig:
def __init__(self, **kwargs):
self.noise_scheduler: 'model_util.AVAILABLE_SCHEDULERS' = kwargs.get('noise_scheduler', 'ddpm')
self.steps: int = kwargs.get('steps', 1000)
self.lr = kwargs.get('lr', 1e-6)
self.optimizer = kwargs.get('optimizer', 'adamw')
self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 50)
self.batch_size: int = kwargs.get('batch_size', 1)
self.dtype: str = kwargs.get('dtype', 'fp32')
self.xformers = kwargs.get('xformers', False)
self.train_unet = kwargs.get('train_unet', True)
self.train_text_encoder = kwargs.get('train_text_encoder', True)
class ModelConfig:
def __init__(self, **kwargs):
self.name_or_path: str = kwargs.get('name_or_path', None)
self.is_v2: bool = kwargs.get('is_v2', False)
self.is_v_pred: bool = kwargs.get('is_v_pred', False)
if self.name_or_path is None:
raise ValueError('name_or_path must be specified')
class PromptSettingsOld:
def __init__(self, **kwargs):
self.target: str = kwargs.get('target', None)
self.positive = kwargs.get('positive', None) # if None, target will be used
self.unconditional = kwargs.get('unconditional', "") # default is ""
self.neutral = kwargs.get('neutral', None) # if None, unconditional will be used
self.action: ACTION_TYPES = kwargs.get('action', "erase") # default is "erase"
self.guidance_scale: float = kwargs.get('guidance_scale', 1.0) # default is 1.0
self.resolution: int = kwargs.get('resolution', 512) # default is 512
self.dynamic_resolution: bool = kwargs.get('dynamic_resolution', False) # default is False
self.batch_size: int = kwargs.get('batch_size', 1) # default is 1
self.dynamic_crops: bool = kwargs.get('dynamic_crops', False) # default is False. only used when model is XL
class TrainSliderProcess(BaseTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
self.step_num = 0
self.start_step = 0
self.device = self.get_conf('device', self.job.device)
self.device_torch = torch.device(self.device)
self.network_config = NetworkConfig(**self.get_conf('network', {}))
self.training_folder = self.get_conf('training_folder', self.job.training_folder)
self.train_config = TrainConfig(**self.get_conf('train', {}))
self.model_config = ModelConfig(**self.get_conf('model', {}))
self.save_config = SaveConfig(**self.get_conf('save', {}))
self.sample_config = SampleConfig(**self.get_conf('sample', {}))
self.logging_config = LogingConfig(**self.get_conf('logging', {}))
self.sd = None
self.prompt_settings = self.get_prompt_settings()
# added later
self.network = None
self.scheduler = None
self.is_flipped = False
def flip_network(self):
for param in self.network.parameters():
# apply opposite weight to the network
param.data = -param.data
self.is_flipped = not self.is_flipped
def get_prompt_settings(self):
prompts = self.get_conf('prompts', required=True)
prompt_settings = [PromptSettingsOld(**prompt) for prompt in prompts]
# for i, prompt in enumerate(prompts):
# prompt_settings[i].fill_prompts(prompt)
return prompt_settings
def sample(self, step=None):
sample_folder = os.path.join(self.save_root, 'samples')
if not os.path.exists(sample_folder):
os.makedirs(sample_folder, exist_ok=True)
self.network.eval()
# save current seed state for training
rng_state = torch.get_rng_state()
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
original_device_dict = {
'vae': self.sd.vae.device,
'unet': self.sd.unet.device,
'text_encoder': self.sd.text_encoder.device,
# 'tokenizer': self.sd.tokenizer.device,
}
self.sd.vae.to(self.device_torch)
self.sd.unet.to(self.device_torch)
self.sd.text_encoder.to(self.device_torch)
# self.sd.tokenizer.to(self.device_torch)
# TODO add clip skip
pipeline = StableDiffusionPipeline(
vae=self.sd.vae,
unet=self.sd.unet,
text_encoder=self.sd.text_encoder,
tokenizer=self.sd.tokenizer,
scheduler=self.sd.noise_scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
# disable progress bar
pipeline.set_progress_bar_config(disable=True)
start_seed = self.sample_config.seed
current_seed = start_seed
pipeline.to(self.device_torch)
with self.network:
with torch.no_grad():
assert self.network.is_active
if self.logging_config.verbose:
print("network_state", {
'is_active': self.network.is_active,
'multiplier': self.network.multiplier,
})
for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}"):
raw_prompt = self.sample_config.prompts[i]
prompt = raw_prompt
neg = self.sample_config.neg
p_split = raw_prompt.split('--n')
if len(p_split) > 1:
prompt = p_split[0].strip()
neg = p_split[1].strip()
height = self.sample_config.height
width = self.sample_config.width
height = max(64, height - height % 8) # round to divisible by 8
width = max(64, width - width % 8) # round to divisible by 8
if self.sample_config.walk_seed:
current_seed += i
torch.manual_seed(current_seed)
torch.cuda.manual_seed(current_seed)
img = pipeline(
prompt,
height=height,
width=width,
num_inference_steps=self.sample_config.sample_steps,
guidance_scale=self.sample_config.guidance_scale,
negative_prompt=neg,
).images[0]
step_num = ''
if step is not None:
# zero-pad 9 digits
step_num = f"_{str(step).zfill(9)}"
seconds_since_epoch = int(time.time())
# zero-pad 2 digits
i_str = str(i).zfill(2)
filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
output_path = os.path.join(sample_folder, filename)
img.save(output_path)
# clear pipeline and cache to reduce vram usage
del pipeline
torch.cuda.empty_cache()
# restore training state
torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
self.sd.vae.to(original_device_dict['vae'])
self.sd.unet.to(original_device_dict['unet'])
self.sd.text_encoder.to(original_device_dict['text_encoder'])
self.network.train()
# self.sd.tokenizer.to(original_device_dict['tokenizer'])
def update_training_metadata(self):
self.add_meta(OrderedDict({"training_info": self.get_training_info()}))
def get_training_info(self):
info = OrderedDict({
'step': self.step_num
})
return info
def save(self, step=None):
if not os.path.exists(self.save_root):
os.makedirs(self.save_root, exist_ok=True)
step_num = ''
if step is not None:
# zeropad 9 digits
step_num = f"_{str(step).zfill(9)}"
self.update_training_metadata()
filename = f'{self.job.name}{step_num}.safetensors'
file_path = os.path.join(self.save_root, filename)
# prepare meta
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
self.network.save_weights(
file_path,
dtype=get_torch_dtype(self.save_config.dtype),
metadata=save_meta
)
self.print(f"Saved to {file_path}")
def run(self):
super().run()
dtype = get_torch_dtype(self.train_config.dtype)
modules = DEFAULT_TARGET_REPLACE
loss = None
if self.network_config.type == "c3lier":
modules += UNET_TARGET_REPLACE_MODULE_CONV
tokenizer, text_encoder, unet, noise_scheduler = model_util.load_models(
self.model_config.name_or_path,
scheduler_name=self.train_config.noise_scheduler,
v2=self.model_config.is_v2,
v_pred=self.model_config.is_v_pred,
)
# just for now or of we want to load a custom one
# put on cpu for now, we only need it when sampling
vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype)
vae.eval()
self.sd = StableDiffusion(vae, tokenizer, text_encoder, unet, noise_scheduler)
text_encoder.to(self.device_torch, dtype=dtype)
text_encoder.eval()
unet.to(self.device_torch, dtype=dtype)
if self.train_config.xformers:
unet.enable_xformers_memory_efficient_attention()
unet.requires_grad_(False)
unet.eval()
self.network = LoRASpecialNetwork(
text_encoder=text_encoder,
unet=unet,
lora_dim=self.network_config.rank,
multiplier=1.0,
alpha=self.network_config.alpha
)
self.network.force_to(self.device_torch, dtype=dtype)
self.network.apply_to(
text_encoder,
unet,
self.train_config.train_text_encoder,
self.train_config.train_unet
)
self.network.prepare_grad_etc(text_encoder, unet)
optimizer_type = self.train_config.optimizer.lower()
# we call it something different than leco
if optimizer_type == "dadaptation":
optimizer_type = "dadaptadam"
optimizer_module = train_util.get_optimizer(optimizer_type)
optimizer = optimizer_module(
self.network.prepare_optimizer_params(
self.train_config.lr, self.train_config.lr, self.train_config.lr
),
lr=self.train_config.lr
)
lr_scheduler = train_util.get_lr_scheduler(
self.train_config.lr_scheduler,
optimizer,
max_iterations=self.train_config.steps,
lr_min=self.train_config.lr / 100, # not sure why leco did this, but ill do it to
)
criteria = torch.nn.MSELoss()
if self.logging_config.verbose:
print("Prompts")
for settings in self.prompt_settings:
print(settings)
# debug
# debug_util.check_requires_grad(network)
# debug_util.check_training_mode(network)
cache = PromptEmbedsCache()
prompt_pairs: list[PromptEmbedsPair] = []
with torch.no_grad():
for settings in self.prompt_settings:
self.print(settings)
for prompt in [
settings.target,
settings.positive,
settings.neutral,
settings.unconditional,
]:
if cache[prompt] == None:
cache[prompt] = train_util.encode_prompts(
tokenizer, text_encoder, [prompt]
)
prompt_pairs.append(
PromptEmbedsPair(
criteria,
cache[settings.target],
cache[settings.positive],
cache[settings.unconditional],
cache[settings.neutral],
settings,
)
)
# move to cpu to save vram
# tokenizer.to("cpu")
text_encoder.to("cpu")
flush()
# sample first
self.print("Generating baseline samples before training")
self.sample(0)
self.progress_bar = tqdm(range(self.train_config.steps))
self.progress_bar = tqdm(
total=self.train_config.steps,
desc=self.job.name,
leave=True
)
self.step_num = 0
for step in range(self.train_config.steps):
with torch.no_grad():
noise_scheduler.set_timesteps(
self.train_config.max_denoising_steps, device=self.device_torch
)
optimizer.zero_grad()
prompt_pair: PromptEmbedsPair = prompt_pairs[
torch.randint(0, len(prompt_pairs), (1,)).item()
]
# 1 ~ 49 random from 1 to 49
timesteps_to = torch.randint(
1, self.train_config.max_denoising_steps, (1,)
).item()
height, width = (
prompt_pair.resolution,
prompt_pair.resolution,
)
if prompt_pair.dynamic_resolution:
height, width = train_util.get_random_resolution_in_bucket(
prompt_pair.resolution
)
if self.logging_config.verbose:
self.print("guidance_scale:", prompt_pair.guidance_scale)
self.print("resolution:", prompt_pair.resolution)
self.print("dynamic_resolution:", prompt_pair.dynamic_resolution)
if prompt_pair.dynamic_resolution:
self.print("bucketed resolution:", (height, width))
self.print("batch_size:", prompt_pair.batch_size)
latents = train_util.get_initial_latents(
noise_scheduler, prompt_pair.batch_size, height, width, 1
).to(self.device_torch, dtype=dtype)
with self.network:
assert self.network.is_active
# A little denoised one is returned
denoised_latents = train_util.diffusion(
unet,
noise_scheduler,
latents, # pass simple noise latents
train_util.concat_embeddings(
prompt_pair.unconditional,
prompt_pair.target,
prompt_pair.batch_size,
),
start_timesteps=0,
total_timesteps=timesteps_to,
guidance_scale=3,
)
noise_scheduler.set_timesteps(1000)
current_timestep = noise_scheduler.timesteps[
int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
]
# with network: Only empty LoRA is enabled outside with network :
positive_latents = train_util.predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
train_util.concat_embeddings(
prompt_pair.unconditional,
prompt_pair.positive,
prompt_pair.batch_size,
),
guidance_scale=1,
).to("cpu", dtype=torch.float32)
neutral_latents = train_util.predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
train_util.concat_embeddings(
prompt_pair.unconditional,
prompt_pair.neutral,
prompt_pair.batch_size,
),
guidance_scale=1,
).to("cpu", dtype=torch.float32)
unconditional_latents = train_util.predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
train_util.concat_embeddings(
prompt_pair.unconditional,
prompt_pair.unconditional,
prompt_pair.batch_size,
),
guidance_scale=1,
).to("cpu", dtype=torch.float32)
# if self.logging_config.verbose:
# self.print("positive_latents:", positive_latents[0, 0, :5, :5])
# self.print("neutral_latents:", neutral_latents[0, 0, :5, :5])
# self.print("unconditional_latents:", unconditional_latents[0, 0, :5, :5])
with self.network:
target_latents = train_util.predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
train_util.concat_embeddings(
prompt_pair.unconditional,
prompt_pair.target,
prompt_pair.batch_size,
),
guidance_scale=1,
).to("cpu", dtype=torch.float32)
# if self.logging_config.verbose:
# self.print("target_latents:", target_latents[0, 0, :5, :5])
positive_latents.requires_grad = False
neutral_latents.requires_grad = False
unconditional_latents.requires_grad = False
loss = prompt_pair.loss(
target_latents=target_latents,
positive_latents=positive_latents,
neutral_latents=neutral_latents,
unconditional_latents=unconditional_latents,
)
loss_float = loss.item()
if self.train_config.optimizer.startswith('dadaptation'):
learning_rate = (
optimizer.param_groups[0]["d"] *
optimizer.param_groups[0]["lr"]
)
else:
learning_rate = optimizer.param_groups[0]['lr']
self.progress_bar.set_postfix_str(f"lr: {learning_rate:.1e} loss: {loss.item():.3e}")
loss.backward()
optimizer.step()
lr_scheduler.step()
del (
positive_latents,
neutral_latents,
unconditional_latents,
target_latents,
latents,
)
flush()
# don't do on first step
if self.step_num != self.start_step:
# pause progress bar
self.progress_bar.unpause() # makes it so doesn't track time
if self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0:
# print above the progress bar
self.sample(self.step_num)
if self.save_config.save_every and self.step_num % self.save_config.save_every == 0:
# print above the progress bar
self.print(f"Saving at step {self.step_num}")
self.save(self.step_num)
if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
# log to tensorboard
if self.writer is not None:
# get avg loss
self.writer.add_scalar(f"loss", loss_float, self.step_num)
if self.train_config.optimizer.startswith('dadaptation'):
learning_rate = (
optimizer.param_groups[0]["d"] *
optimizer.param_groups[0]["lr"]
)
else:
learning_rate = optimizer.param_groups[0]['lr']
self.writer.add_scalar(f"lr", learning_rate, self.step_num)
self.progress_bar.refresh()
# sets progress bar to match out step
self.progress_bar.update(step - self.progress_bar.n)
# end of step
self.step_num = step
self.save()
del (
unet,
noise_scheduler,
loss,
optimizer,
self.network,
tokenizer,
text_encoder,
)
flush()

View File

@@ -178,7 +178,6 @@ class TrainVAEProcess(BaseTrainProcess):
self.device = self.get_conf('device', self.job.device) self.device = self.get_conf('device', self.job.device)
self.vae_path = self.get_conf('vae_path', required=True) self.vae_path = self.get_conf('vae_path', required=True)
self.datasets_objects = self.get_conf('datasets', required=True) self.datasets_objects = self.get_conf('datasets', required=True)
self.training_folder = self.get_conf('training_folder', self.job.training_folder)
self.batch_size = self.get_conf('batch_size', 1, as_type=int) self.batch_size = self.get_conf('batch_size', 1, as_type=int)
self.resolution = self.get_conf('resolution', 256, as_type=int) self.resolution = self.get_conf('resolution', 256, as_type=int)
self.learning_rate = self.get_conf('learning_rate', 1e-6, as_type=float) self.learning_rate = self.get_conf('learning_rate', 1e-6, as_type=float)
@@ -197,14 +196,10 @@ class TrainVAEProcess(BaseTrainProcess):
self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float) self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
self.critic_weight = self.get_conf('critic_weight', 1, as_type=float) self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float) self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float)
self.first_step = 0
self.blocks_to_train = self.get_conf('blocks_to_train', ['all']) self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
self.writer = self.job.writer
self.torch_dtype = get_torch_dtype(self.dtype) self.torch_dtype = get_torch_dtype(self.dtype)
self.save_root = os.path.join(self.training_folder, self.job.name)
self.vgg_19 = None self.vgg_19 = None
self.progress_bar = None
self.style_weight_scalers = [] self.style_weight_scalers = []
self.content_weight_scalers = [] self.content_weight_scalers = []
@@ -254,13 +249,6 @@ class TrainVAEProcess(BaseTrainProcess):
}) })
return info return info
def print(self, message, **kwargs):
if self.progress_bar is not None:
self.progress_bar.write(message, **kwargs)
self.progress_bar.update()
else:
print(message, **kwargs)
def load_datasets(self): def load_datasets(self):
if self.data_loader is None: if self.data_loader is None:
print(f"Loading datasets") print(f"Loading datasets")

View File

@@ -5,3 +5,4 @@ from .BaseProcess import BaseProcess
from .BaseTrainProcess import BaseTrainProcess from .BaseTrainProcess import BaseTrainProcess
from .TrainVAEProcess import TrainVAEProcess from .TrainVAEProcess import TrainVAEProcess
from .BaseMergeProcess import BaseMergeProcess from .BaseMergeProcess import BaseMergeProcess
from .TrainSliderProcess import TrainSliderProcess

View File

@@ -8,4 +8,5 @@ flatten_json
accelerator accelerator
pyyaml pyyaml
oyaml oyaml
tensorboard tensorboard
kornia

238
toolkit/lora.py Normal file
View File

@@ -0,0 +1,238 @@
# ref:
# - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
# - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
# - https://github.com/p1atdev/LECO/blob/main/lora.py
import os
import math
from typing import Optional, List, Type, Set, Literal
import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from safetensors.torch import save_file
UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
"Transformer2DModel", # どうやらこっちの方らしい? # attn1, 2
]
UNET_TARGET_REPLACE_MODULE_CONV = [
"ResnetBlock2D",
"Downsample2D",
"Upsample2D",
] # locon, 3clier
LORA_PREFIX_UNET = "lora_unet"
DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER
TRAINING_METHODS = Literal[
"noxattn", # train all layers except x-attns and time_embed layers
"innoxattn", # train all layers except self attention layers
"selfattn", # ESD-u, train only self attention layers
"xattn", # ESD-x, train only x attention layers
"full", # train all layers
# "notime",
# "xlayer",
# "outxattn",
# "outsattn",
# "inxattn",
# "inmidsattn",
# "selflayer",
]
class LoRAModule(nn.Module):
"""
replaces forward method of the original Linear, instead of replacing the original Linear module.
"""
def __init__(
self,
lora_name,
org_module: nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__()
self.lora_name = lora_name
self.lora_dim = lora_dim
if org_module.__class__.__name__ == "Linear":
in_dim = org_module.in_features
out_dim = org_module.out_features
self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
elif org_module.__class__.__name__ == "Conv2d": # 一応
in_dim = org_module.in_channels
out_dim = org_module.out_channels
self.lora_dim = min(self.lora_dim, in_dim, out_dim)
if self.lora_dim != lora_dim:
print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
kernel_size = org_module.kernel_size
stride = org_module.stride
padding = org_module.padding
self.lora_down = nn.Conv2d(
in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
)
self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
if type(alpha) == torch.Tensor:
alpha = alpha.detach().numpy()
alpha = lora_dim if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.lora_dim
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
# same as microsoft's
nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_up.weight)
self.multiplier = multiplier
self.org_module = org_module # remove in applying
def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
del self.org_module
def forward(self, x):
return (
self.org_forward(x)
+ self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
)
class LoRANetwork(nn.Module):
def __init__(
self,
unet: UNet2DConditionModel,
rank: int = 4,
multiplier: float = 1.0,
alpha: float = 1.0,
train_method: TRAINING_METHODS = "full",
) -> None:
super().__init__()
self.multiplier = multiplier
self.lora_dim = rank
self.alpha = alpha
# LoRAのみ
self.module = LoRAModule
# unetのloraを作る
self.unet_loras = self.create_modules(
LORA_PREFIX_UNET,
unet,
DEFAULT_TARGET_REPLACE,
self.lora_dim,
self.multiplier,
train_method=train_method,
)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
# assertion 名前の被りがないか確認しているようだ
lora_names = set()
for lora in self.unet_loras:
assert (
lora.lora_name not in lora_names
), f"duplicated lora name: {lora.lora_name}. {lora_names}"
lora_names.add(lora.lora_name)
# 適用する
for lora in self.unet_loras:
lora.apply_to()
self.add_module(
lora.lora_name,
lora,
)
del unet
torch.cuda.empty_cache()
def create_modules(
self,
prefix: str,
root_module: nn.Module,
target_replace_modules: List[str],
rank: int,
multiplier: float,
train_method: TRAINING_METHODS,
) -> list:
loras = []
for name, module in root_module.named_modules():
if train_method == "noxattn": # Cross Attention と Time Embed 以外学習
if "attn2" in name or "time_embed" in name:
continue
elif train_method == "innoxattn": # Cross Attention 以外学習
if "attn2" in name:
continue
elif train_method == "selfattn": # Self Attention のみ学習
if "attn1" not in name:
continue
elif train_method == "xattn": # Cross Attention のみ学習
if "attn2" not in name:
continue
elif train_method == "full": # 全部学習
pass
else:
raise NotImplementedError(
f"train_method: {train_method} is not implemented."
)
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
if child_module.__class__.__name__ in ["Linear", "Conv2d"]:
lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace(".", "_")
print(f"{lora_name}")
lora = self.module(
lora_name, child_module, multiplier, rank, self.alpha
)
loras.append(lora)
return loras
def prepare_optimizer_params(self):
all_params = []
if self.unet_loras: # 実質これしかない
params = []
[params.extend(lora.parameters()) for lora in self.unet_loras]
param_data = {"params": params}
all_params.append(param_data)
return all_params
def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
state_dict = self.state_dict()
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
for key in list(state_dict.keys()):
if not key.startswith("lora"):
# lora以外除外
del state_dict[key]
if os.path.splitext(file)[1] == ".safetensors":
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
def __enter__(self):
for lora in self.unet_loras:
lora.multiplier = 1.0
def __exit__(self, exc_type, exc_value, tb):
for lora in self.unet_loras:
lora.multiplier = 0

226
toolkit/lora_special.py Normal file
View File

@@ -0,0 +1,226 @@
import os
import sys
from typing import List
import torch
from .paths import SD_SCRIPTS_ROOT
sys.path.append(SD_SCRIPTS_ROOT)
from networks.lora import LoRANetwork, LoRAModule, get_block_index
class LoRASpecialNetwork(LoRANetwork):
_multiplier: float = 1.0
is_active: bool = False
def __init__(
self,
text_encoder,
unet,
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=None,
rank_dropout=None,
module_dropout=None,
conv_lora_dim=None,
conv_alpha=None,
block_dims=None,
block_alphas=None,
conv_block_dims=None,
conv_block_alphas=None,
modules_dim=None,
modules_alpha=None,
module_class=LoRAModule,
varbose=False,
) -> None:
"""
LoRA network: すごく引数が多いが、パターンは以下の通り
1. lora_dimとalphaを指定
2. lora_dim、alpha、conv_lora_dim、conv_alphaを指定
3. block_dimsとblock_alphasを指定 : Conv2d3x3には適用しない
4. block_dims、block_alphas、conv_block_dims、conv_block_alphasを指定 : Conv2d3x3にも適用する
5. modules_dimとmodules_alphaを指定 (推論用)
"""
# call the parent of the parent we are replacing (LoRANetwork) init
super(LoRANetwork, self).__init__()
self.multiplier = multiplier
self.lora_dim = lora_dim
self.alpha = alpha
self.conv_lora_dim = conv_lora_dim
self.conv_alpha = conv_alpha
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
if modules_dim is not None:
print(f"create LoRA network from weights")
elif block_dims is not None:
print(f"create LoRA network from block_dims")
print(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
print(f"block_dims: {block_dims}")
print(f"block_alphas: {block_alphas}")
if conv_block_dims is not None:
print(f"conv_block_dims: {conv_block_dims}")
print(f"conv_block_alphas: {conv_block_alphas}")
else:
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
print(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
if self.conv_lora_dim is not None:
print(
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
# create module instances
def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
prefix = LoRANetwork.LORA_PREFIX_UNET if is_unet else LoRANetwork.LORA_PREFIX_TEXT_ENCODER
loras = []
skipped = []
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d:
lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace(".", "_")
dim = None
alpha = None
if modules_dim is not None:
if lora_name in modules_dim:
dim = modules_dim[lora_name]
alpha = modules_alpha[lora_name]
elif is_unet and block_dims is not None:
block_idx = get_block_index(lora_name)
if is_linear or is_conv2d_1x1:
dim = block_dims[block_idx]
alpha = block_alphas[block_idx]
elif conv_block_dims is not None:
dim = conv_block_dims[block_idx]
alpha = conv_block_alphas[block_idx]
else:
if is_linear or is_conv2d_1x1:
dim = self.lora_dim
alpha = self.alpha
elif self.conv_lora_dim is not None:
dim = self.conv_lora_dim
alpha = self.conv_alpha
if dim is None or dim == 0:
if is_linear or is_conv2d_1x1 or (
self.conv_lora_dim is not None or conv_block_dims is not None):
skipped.append(lora_name)
continue
lora = module_class(
lora_name,
child_module,
self.multiplier,
dim,
alpha,
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
)
loras.append(lora)
return loras, skipped
self.text_encoder_loras, skipped_te = create_modules(False, text_encoder,
LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras, skipped_un = create_modules(True, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
skipped = skipped_te + skipped_un
if varbose and len(skipped) > 0:
print(
f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
)
for name in skipped:
print(f"\t{name}")
self.up_lr_weight: List[float] = None
self.down_lr_weight: List[float] = None
self.mid_lr_weight: float = None
self.block_lr = False
# assertion
names = set()
for lora in self.text_encoder_loras + self.unet_loras:
# doesnt work on new diffusers. TODO make sure we are not missing something
# assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)
def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
@property
def multiplier(self):
return self._multiplier
@multiplier.setter
def multiplier(self, value):
self._multiplier = value
self._update_lora_multiplier()
def _update_lora_multiplier(self):
if self.is_active:
if hasattr(self, 'unet_loras'):
for lora in self.unet_loras:
lora.multiplier = self._multiplier
if hasattr(self, 'text_encoder_loras'):
for lora in self.text_encoder_loras:
lora.multiplier = self._multiplier
else:
if hasattr(self, 'unet_loras'):
for lora in self.unet_loras:
lora.multiplier = 0
if hasattr(self, 'text_encoder_loras'):
for lora in self.text_encoder_loras:
lora.multiplier = 0
def __enter__(self):
self.is_active = True
self._update_lora_multiplier()
def __exit__(self, exc_type, exc_value, tb):
self.is_active = False
self._update_lora_multiplier()
def force_to(self, device, dtype):
self.to(device, dtype)
loras = []
if hasattr(self, 'unet_loras'):
loras += self.unet_loras
if hasattr(self, 'text_encoder_loras'):
loras += self.text_encoder_loras
for lora in loras:
lora.to(device, dtype)

View File

@@ -83,3 +83,5 @@ class PatternLoss(torch.nn.Module):
g_chan_loss = torch.abs(separated_chan_loss(g_chans) - separated_chan_loss(g_chans_target)) g_chan_loss = torch.abs(separated_chan_loss(g_chans) - separated_chan_loss(g_chans_target))
b_chan_loss = torch.abs(separated_chan_loss(b_chans) - separated_chan_loss(b_chans_target)) b_chan_loss = torch.abs(separated_chan_loss(b_chans) - separated_chan_loss(b_chans_target))
return (r_chan_loss + g_chan_loss + b_chan_loss) * 0.3333 return (r_chan_loss + g_chan_loss + b_chan_loss) * 0.3333