SDXL should be working, but I broke something where it is not converging.

This commit is contained in:
Jaret Burkett
2023-07-25 13:50:59 -06:00
parent 52f02d53f1
commit cb70c03273
11 changed files with 458 additions and 166 deletions

View File

@@ -17,6 +17,7 @@ class BaseJob:
raise ValueError('config is required')
self.config = config['config']
self.raw_config = config
self.job = config['job']
self.name = self.get_conf('name', required=True)
if 'meta' in config:

View File

@@ -1,3 +1,4 @@
import json
import os
from jobs import BaseJob
@@ -6,7 +7,7 @@ from collections import OrderedDict
from typing import List
from jobs.process import BaseExtractProcess, TrainFineTuneProcess
from datetime import datetime
import yaml
from toolkit.paths import REPOS_ROOT
import sys
@@ -16,6 +17,7 @@ sys.path.append(REPOS_ROOT)
process_dict = {
'vae': 'TrainVAEProcess',
'slider': 'TrainSliderProcess',
'lora_hack': 'TrainLoRAHack',
}
@@ -37,6 +39,13 @@ class TrainJob(BaseJob):
# loads the processes from the config
self.load_processes(process_dict)
def save_training_config(self):
timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
os.makedirs(self.training_folder, exist_ok=True)
save_dif = os.path.join(self.training_folder, f'run_config_{timestamp}.yaml')
with open(save_dif, 'w') as f:
yaml.dump(self.raw_config, f)
def run(self):
super().run()
print("")

View File

@@ -2,7 +2,6 @@ import time
from collections import OrderedDict
import os
from leco.train_util import predict_noise
from toolkit.kohya_model_util import load_vae
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.optimizer import get_optimizer
@@ -12,7 +11,7 @@ import sys
sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from diffusers import StableDiffusionPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from jobs.process import BaseTrainProcess
from toolkit.metadata import get_meta_for_safetensors
@@ -24,6 +23,7 @@ from tqdm import tqdm
from leco import train_util, model_util
from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig
from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds
def flush():
@@ -35,15 +35,6 @@ UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも
VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
class StableDiffusion:
def __init__(self, vae, tokenizer, text_encoder, unet, noise_scheduler):
self.vae = vae
self.tokenizer = tokenizer
self.text_encoder = text_encoder
self.unet = unet
self.noise_scheduler = noise_scheduler
class BaseSDTrainProcess(BaseTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
@@ -80,26 +71,44 @@ class BaseSDTrainProcess(BaseTrainProcess):
original_device_dict = {
'vae': self.sd.vae.device,
'unet': self.sd.unet.device,
'text_encoder': self.sd.text_encoder.device,
# 'tokenizer': self.sd.tokenizer.device,
}
# handle sdxl text encoder
if isinstance(self.sd.text_encoder, list):
for encoder, i in zip(self.sd.text_encoder, range(len(self.sd.text_encoder))):
original_device_dict[f'text_encoder_{i}'] = encoder.device
encoder.to(self.device_torch)
else:
original_device_dict['text_encoder'] = self.sd.text_encoder.device
self.sd.text_encoder.to(self.device_torch)
self.sd.vae.to(self.device_torch)
self.sd.unet.to(self.device_torch)
self.sd.text_encoder.to(self.device_torch)
# self.sd.text_encoder.to(self.device_torch)
# self.sd.tokenizer.to(self.device_torch)
# TODO add clip skip
pipeline = StableDiffusionPipeline(
vae=self.sd.vae,
unet=self.sd.unet,
text_encoder=self.sd.text_encoder,
tokenizer=self.sd.tokenizer,
scheduler=self.sd.noise_scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
if self.sd.is_xl:
pipeline = StableDiffusionXLPipeline(
vae=self.sd.vae,
unet=self.sd.unet,
text_encoder=self.sd.text_encoder[0],
text_encoder_2=self.sd.text_encoder[1],
tokenizer=self.sd.tokenizer[0],
tokenizer_2=self.sd.tokenizer[1],
scheduler=self.sd.noise_scheduler,
)
else:
pipeline = StableDiffusionPipeline(
vae=self.sd.vae,
unet=self.sd.unet,
text_encoder=self.sd.text_encoder,
tokenizer=self.sd.tokenizer,
scheduler=self.sd.noise_scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
# disable progress bar
pipeline.set_progress_bar_config(disable=True)
@@ -118,7 +127,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
'multiplier': self.network.multiplier,
})
for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}", leave=False):
for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}",
leave=False):
raw_prompt = self.sample_config.prompts[i]
neg = self.sample_config.neg
@@ -180,7 +190,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.sd.vae.to(original_device_dict['vae'])
self.sd.unet.to(original_device_dict['unet'])
self.sd.text_encoder.to(original_device_dict['text_encoder'])
if isinstance(self.sd.text_encoder, list):
for encoder, i in zip(self.sd.text_encoder, range(len(self.sd.text_encoder))):
encoder.to(original_device_dict[f'text_encoder_{i}'])
else:
self.sd.text_encoder.to(original_device_dict['text_encoder'])
if self.network is not None:
self.network.train()
self.network.multiplier = start_multiplier
@@ -267,23 +281,90 @@ class BaseSDTrainProcess(BaseTrainProcess):
# return loss
return 0.0
# ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
def diffuse_some_steps(
def get_time_ids_from_latents(self, latents):
bs, ch, h, w = list(latents.shape)
height = h * VAE_SCALE_FACTOR
width = w * VAE_SCALE_FACTOR
dtype = get_torch_dtype(self.train_config.dtype)
if self.sd.is_xl:
prompt_ids = train_util.get_add_time_ids(
height,
width,
dynamic_crops=False, # look into this
dtype=dtype,
).to(self.device_torch, dtype=dtype)
return train_util.concat_embeddings(
prompt_ids, prompt_ids, bs
)
else:
return None
def predict_noise(
self,
latents: torch.FloatTensor,
text_embeddings: torch.FloatTensor,
total_timesteps: int = 1000,
start_timesteps=0,
text_embeddings: PromptEmbeds,
timestep: int,
guidance_scale=7.5,
guidance_rescale=0.7,
add_time_ids=None,
**kwargs,
):
for timestep in tqdm(self.sd.noise_scheduler.timesteps[start_timesteps:total_timesteps], leave=False):
noise_pred = train_util.predict_noise(
self.sd.unet, self.sd.noise_scheduler, timestep, latents, text_embeddings, **kwargs
if self.sd.is_xl:
if add_time_ids is None:
add_time_ids = self.get_time_ids_from_latents(latents)
# todo LECOs code looks like it is omitting noise_pred
noise_pred = train_util.predict_noise_xl(
self.sd.unet,
self.sd.noise_scheduler,
timestep,
latents,
text_embeddings.text_embeds,
text_embeddings.pooled_embeds,
add_time_ids,
guidance_scale=guidance_scale,
guidance_rescale=guidance_rescale
)
# compute the previous noisy sample x_t -> x_t-1
latents = self.sd.noise_scheduler.step(noise_pred, timestep, latents).prev_sample
else:
noise_pred = train_util.predict_noise(
self.sd.unet,
self.sd.noise_scheduler,
timestep,
latents,
text_embeddings.text_embeds,
guidance_scale=guidance_scale
)
# compute the previous noisy sample x_t -> x_t-1
latents = self.sd.noise_scheduler.step(noise_pred, timestep, latents).prev_sample
return latents
# ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
def diffuse_some_steps(
self,
latents: torch.FloatTensor,
text_embeddings: PromptEmbeds,
total_timesteps: int = 1000,
start_timesteps=0,
guidance_scale=1,
add_time_ids=None,
**kwargs,
):
for timestep in tqdm(self.sd.noise_scheduler.timesteps[start_timesteps:total_timesteps], leave=False):
latents = self.predict_noise(
latents,
text_embeddings,
timestep,
guidance_scale=guidance_scale,
add_time_ids=add_time_ids,
**kwargs,
)
# return latents_steps
return latents
@@ -296,20 +377,35 @@ class BaseSDTrainProcess(BaseTrainProcess):
dtype = get_torch_dtype(self.train_config.dtype)
tokenizer, text_encoder, unet, noise_scheduler = model_util.load_models(
self.model_config.name_or_path,
scheduler_name=self.train_config.noise_scheduler,
v2=self.model_config.is_v2,
v_pred=self.model_config.is_v_pred,
)
if self.model_config.is_xl:
tokenizer, text_encoders, unet, noise_scheduler = model_util.load_models_xl(
self.model_config.name_or_path,
scheduler_name=self.train_config.noise_scheduler,
weight_dtype=dtype,
)
for text_encoder in text_encoders:
text_encoder.to(self.device_torch, dtype=dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
text_encoder = text_encoders
else:
tokenizer, text_encoder, unet, noise_scheduler = model_util.load_models(
self.model_config.name_or_path,
scheduler_name=self.train_config.noise_scheduler,
v2=self.model_config.is_v2,
v_pred=self.model_config.is_v_pred,
)
text_encoder.to(self.device_torch, dtype=dtype)
text_encoder.eval()
# just for now or of we want to load a custom one
# put on cpu for now, we only need it when sampling
vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype)
vae.eval()
self.sd = StableDiffusion(vae, tokenizer, text_encoder, unet, noise_scheduler)
text_encoder.to(self.device_torch, dtype=dtype)
text_encoder.eval()
self.sd = StableDiffusion(vae, tokenizer, text_encoder, unet, noise_scheduler, is_xl=self.model_config.is_xl)
unet.to(self.device_torch, dtype=dtype)
if self.train_config.xformers:
@@ -323,7 +419,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
unet=unet,
lora_dim=self.network_config.rank,
multiplier=1.0,
alpha=self.network_config.alpha
alpha=self.network_config.alpha,
train_unet=self.train_config.train_unet,
train_text_encoder=self.train_config.train_text_encoder,
)
self.network.force_to(self.device_torch, dtype=dtype)
@@ -376,8 +474,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.hook_before_train_loop()
# sample first
self.print("Generating baseline samples before training")
self.sample(0)
if self.train_config.skip_first_sample:
self.print("Skipping first sample due to config setting")
else:
self.print("Generating baseline samples before training")
self.sample(0)
self.progress_bar = tqdm(
total=self.train_config.steps,

View File

@@ -0,0 +1,76 @@
# ref:
# - https://github.com/p1atdev/LECO/blob/main/train_lora.py
import time
from collections import OrderedDict
import os
from toolkit.config_modules import SliderConfig
from toolkit.paths import REPOS_ROOT
import sys
sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import gc
import torch
from leco import train_util, model_util
from leco.prompt_util import PromptEmbedsCache
from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion
def flush():
torch.cuda.empty_cache()
gc.collect()
class LoRAHack:
def __init__(self, **kwargs):
self.type = kwargs.get('type', 'suppression')
class TrainLoRAHack(BaseSDTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
self.hack_config = LoRAHack(**self.get_conf('hack', {}))
def hook_before_train_loop(self):
# we don't need text encoder so move it to cpu
self.sd.text_encoder.to("cpu")
flush()
# end hook_before_train_loop
if self.hack_config.type == 'suppression':
# set all params to self.current_suppression
params = self.network.parameters()
for param in params:
# get random noise for each param
noise = torch.randn_like(param) - 0.5
# apply noise to param
param.data = noise * 0.001
def supress_loop(self):
dtype = get_torch_dtype(self.train_config.dtype)
loss_dict = OrderedDict(
{'sup': 0.0}
)
# increase noise
for param in self.network.parameters():
# get random noise for each param
noise = torch.randn_like(param) - 0.5
# apply noise to param
param.data = param.data + noise * 0.001
return loss_dict
def hook_train_loop(self):
if self.hack_config.type == 'suppression':
return self.supress_loop()
else:
raise NotImplementedError(f'unknown hack type: {self.hack_config.type}')
# end hook_train_loop

View File

@@ -3,19 +3,22 @@
import time
from collections import OrderedDict
import os
from typing import Optional
from toolkit.config_modules import SliderConfig
from toolkit.paths import REPOS_ROOT
import sys
from toolkit.stable_diffusion_model import PromptEmbeds
sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import gc
from toolkit import train_tools
import torch
from leco import train_util, model_util
from leco.prompt_util import PromptEmbedsCache
from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion
@@ -29,7 +32,6 @@ def flush():
gc.collect()
class EncodedPromptPair:
def __init__(
self,
@@ -54,6 +56,19 @@ class EncodedPromptPair:
self.weight = weight
class PromptEmbedsCache: # 使いまわしたいので
prompts: dict[str, PromptEmbeds] = {}
def __setitem__(self, __name: str, __value: PromptEmbeds) -> None:
self.prompts[__name] = __value
def __getitem__(self, __name: str) -> Optional[PromptEmbeds]:
if __name in self.prompts:
return self.prompts[__name]
else:
return None
class EncodedAnchor:
def __init__(
self,
@@ -89,19 +104,17 @@ class TrainSliderProcess(BaseSDTrainProcess):
with torch.no_grad():
neutral = ""
for target in self.slider_config.targets:
# build the cache
for prompt in [
target.target_class,
target.positive,
target.negative,
neutral # empty neutral
]:
if cache[prompt] is None:
cache[prompt] = self.sd.encode_prompt(prompt)
for resolution in self.slider_config.resolutions:
width, height = resolution
# build the cache
for prompt in [
target.target_class,
target.positive,
target.negative,
neutral # empty neutral
]:
if cache[prompt] == None:
cache[prompt] = train_util.encode_prompts(
self.sd.tokenizer, self.sd.text_encoder, [prompt]
)
only_erase = len(target.positive.strip()) == 0
only_enhance = len(target.negative.strip()) == 0
@@ -184,9 +197,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
anchor.neg_prompt # empty neutral
]:
if cache[prompt] == None:
cache[prompt] = train_util.encode_prompts(
self.sd.tokenizer, self.sd.text_encoder, [prompt]
)
cache[prompt] = self.sd.encode_prompt(prompt)
anchor_pairs += [
EncodedAnchor(
@@ -198,7 +209,12 @@ class TrainSliderProcess(BaseSDTrainProcess):
# move to cpu to save vram
# We don't need text encoder anymore, but keep it on cpu for sampling
self.sd.text_encoder.to("cpu")
# if text encoder is list
if isinstance(self.sd.text_encoder, list):
for encoder in self.sd.text_encoder:
encoder.to("cpu")
else:
self.sd.text_encoder.to("cpu")
self.prompt_cache = cache
self.prompt_pairs = prompt_pairs
self.anchor_pairs = anchor_pairs
@@ -220,6 +236,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
negative = prompt_pair.negative
positive = prompt_pair.positive
weight = prompt_pair.weight
multiplier = prompt_pair.multiplier
unet = self.sd.unet
noise_scheduler = self.sd.noise_scheduler
@@ -227,8 +244,20 @@ class TrainSliderProcess(BaseSDTrainProcess):
lr_scheduler = self.lr_scheduler
loss_function = torch.nn.MSELoss()
def get_noise_pred(p, n):
return self.predict_noise(
latents=denoised_latents,
text_embeddings=train_tools.concat_prompt_embeddings(
p, # unconditional
n, # positive
self.train_config.batch_size,
),
timestep=current_timestep,
guidance_scale=1,
)
# set network multiplier
self.network.multiplier = prompt_pair.multiplier
self.network.multiplier = multiplier
with torch.no_grad():
self.sd.noise_scheduler.set_timesteps(
@@ -254,9 +283,10 @@ class TrainSliderProcess(BaseSDTrainProcess):
with self.network:
assert self.network.is_active
self.network.multiplier = multiplier
denoised_latents = self.diffuse_some_steps(
latents, # pass simple noise latents
train_util.concat_embeddings(
train_tools.concat_prompt_embeddings(
positive, # unconditional
target_class, # target
self.train_config.batch_size,
@@ -272,43 +302,11 @@ class TrainSliderProcess(BaseSDTrainProcess):
int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
]
# with network: 0 weight LoRA is enabled outside "with network:"
positive_latents = train_util.predict_noise( # positive_latents
unet,
noise_scheduler,
current_timestep,
denoised_latents,
train_util.concat_embeddings(
positive, # unconditional
negative, # positive
self.train_config.batch_size,
),
guidance_scale=1,
).to("cpu", dtype=torch.float32)
neutral_latents = train_util.predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
train_util.concat_embeddings(
positive, # unconditional
neutral, # neutral
self.train_config.batch_size,
),
guidance_scale=1,
).to("cpu", dtype=torch.float32)
unconditional_latents = train_util.predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
train_util.concat_embeddings(
positive, # unconditional
positive, # unconditional
self.train_config.batch_size,
),
guidance_scale=1,
).to("cpu", dtype=torch.float32)
positive_latents = get_noise_pred(positive, negative)
neutral_latents = get_noise_pred(positive, neutral)
unconditional_latents = get_noise_pred(positive, positive)
anchor_loss = None
if len(self.anchor_pairs) > 0:
@@ -317,51 +315,19 @@ class TrainSliderProcess(BaseSDTrainProcess):
torch.randint(0, len(self.anchor_pairs), (1,)).item()
]
with torch.no_grad():
anchor_target_noise = train_util.predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
train_util.concat_embeddings(
anchor.prompt,
anchor.neg_prompt,
self.train_config.batch_size,
),
guidance_scale=1,
).to("cpu", dtype=torch.float32)
anchor_target_noise = get_noise_pred(anchor.prompt, anchor.neg_prompt)
with self.network:
# anchor whatever weight prompt pair is using
pos_nem_mult = 1.0 if prompt_pair.multiplier > 0 else -1.0
self.network.multiplier = anchor.multiplier * pos_nem_mult
anchor_pred_noise = train_util.predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
train_util.concat_embeddings(
anchor.prompt,
anchor.neg_prompt,
self.train_config.batch_size,
),
guidance_scale=1,
).to("cpu", dtype=torch.float32)
anchor_pred_noise = get_noise_pred(anchor.prompt, anchor.neg_prompt)
self.network.multiplier = prompt_pair.multiplier
with self.network:
self.network.multiplier = prompt_pair.multiplier
target_latents = train_util.predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
train_util.concat_embeddings(
positive, # unconditional
target_class, # target
self.train_config.batch_size,
),
guidance_scale=1,
).to("cpu", dtype=torch.float32)
target_latents = get_noise_pred(positive, target_class)
# if self.logging_config.verbose:
# self.print("target_latents:", target_latents[0, 0, :5, :5])

View File

@@ -6,3 +6,4 @@ from .BaseTrainProcess import BaseTrainProcess
from .TrainVAEProcess import TrainVAEProcess
from .BaseMergeProcess import BaseMergeProcess
from .TrainSliderProcess import TrainSliderProcess
from .TrainLoRAHack import TrainLoRAHack

View File

@@ -9,4 +9,5 @@ accelerator
pyyaml
oyaml
tensorboard
kornia
kornia
invisible-watermark

View File

@@ -50,12 +50,14 @@ class TrainConfig:
self.train_text_encoder = kwargs.get('train_text_encoder', True)
self.noise_offset = kwargs.get('noise_offset', 0.0)
self.optimizer_params = kwargs.get('optimizer_params', {})
self.skip_first_sample = kwargs.get('skip_first_sample', False)
class ModelConfig:
def __init__(self, **kwargs):
self.name_or_path: str = kwargs.get('name_or_path', None)
self.is_v2: bool = kwargs.get('is_v2', False)
self.is_xl: bool = kwargs.get('is_xl', False)
self.is_v_pred: bool = kwargs.get('is_v_pred', False)
if self.name_or_path is None:

View File

@@ -1,8 +1,10 @@
import os
import sys
from typing import List
from typing import List, Optional, Dict, Type, Union
import torch
from transformers import CLIPTextModel
from .paths import SD_SCRIPTS_ROOT
sys.path.append(SD_SCRIPTS_ROOT)
@@ -14,26 +16,40 @@ class LoRASpecialNetwork(LoRANetwork):
_multiplier: float = 1.0
is_active: bool = False
NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"
# SDXL: must starts with LORA_PREFIX_TEXT_ENCODER
LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
def __init__(
self,
text_encoder,
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
unet,
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=None,
rank_dropout=None,
module_dropout=None,
conv_lora_dim=None,
conv_alpha=None,
block_dims=None,
block_alphas=None,
conv_block_dims=None,
conv_block_alphas=None,
modules_dim=None,
modules_alpha=None,
module_class=LoRAModule,
varbose=False,
multiplier: float = 1.0,
lora_dim: int = 4,
alpha: float = 1,
dropout: Optional[float] = None,
rank_dropout: Optional[float] = None,
module_dropout: Optional[float] = None,
conv_lora_dim: Optional[int] = None,
conv_alpha: Optional[float] = None,
block_dims: Optional[List[int]] = None,
block_alphas: Optional[List[float]] = None,
conv_block_dims: Optional[List[int]] = None,
conv_block_alphas: Optional[List[float]] = None,
modules_dim: Optional[Dict[str, int]] = None,
modules_alpha: Optional[Dict[str, int]] = None,
module_class: Type[object] = LoRAModule,
varbose: Optional[bool] = False,
train_text_encoder: Optional[bool] = True,
train_unet: Optional[bool] = True,
) -> None:
"""
LoRA network: すごく引数が多いが、パターンは以下の通り
@@ -75,8 +91,21 @@ class LoRASpecialNetwork(LoRANetwork):
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
# create module instances
def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
prefix = LoRANetwork.LORA_PREFIX_UNET if is_unet else LoRANetwork.LORA_PREFIX_TEXT_ENCODER
def create_modules(
is_unet: bool,
text_encoder_idx: Optional[int], # None, 1, 2
root_module: torch.nn.Module,
target_replace_modules: List[torch.nn.Module],
) -> List[LoRAModule]:
prefix = (
self.LORA_PREFIX_UNET
if is_unet
else (
self.LORA_PREFIX_TEXT_ENCODER
if text_encoder_idx is None
else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
)
)
loras = []
skipped = []
for name, module in root_module.named_modules():
@@ -92,11 +121,14 @@ class LoRASpecialNetwork(LoRANetwork):
dim = None
alpha = None
if modules_dim is not None:
# モジュール指定あり
if lora_name in modules_dim:
dim = modules_dim[lora_name]
alpha = modules_alpha[lora_name]
elif is_unet and block_dims is not None:
# U-Netでblock_dims指定あり
block_idx = get_block_index(lora_name)
if is_linear or is_conv2d_1x1:
dim = block_dims[block_idx]
@@ -105,6 +137,7 @@ class LoRASpecialNetwork(LoRANetwork):
dim = conv_block_dims[block_idx]
alpha = conv_block_alphas[block_idx]
else:
# 通常、すべて対象とする
if is_linear or is_conv2d_1x1:
dim = self.lora_dim
alpha = self.alpha
@@ -113,6 +146,7 @@ class LoRASpecialNetwork(LoRANetwork):
alpha = self.conv_alpha
if dim is None or dim == 0:
# skipした情報を出力
if is_linear or is_conv2d_1x1 or (
self.conv_lora_dim is not None or conv_block_dims is not None):
skipped.append(lora_name)
@@ -131,8 +165,25 @@ class LoRASpecialNetwork(LoRANetwork):
loras.append(lora)
return loras, skipped
self.text_encoder_loras, skipped_te = create_modules(False, text_encoder,
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
# create LoRA for text encoder
# 毎回すべてのモジュールを作るのは無駄なので要検討
self.text_encoder_loras = []
skipped_te = []
if train_text_encoder:
for i, text_encoder in enumerate(text_encoders):
if len(text_encoders) > 1:
index = i + 1
print(f"create LoRA for Text Encoder {index}:")
else:
index = None
print(f"create LoRA for Text Encoder:")
text_encoder_loras, skipped = create_modules(False, index, text_encoder,
LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
@@ -140,7 +191,11 @@ class LoRASpecialNetwork(LoRANetwork):
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras, skipped_un = create_modules(True, unet, target_modules)
if train_unet:
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
else:
self.unet_loras = []
skipped_un = []
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
skipped = skipped_te + skipped_un
@@ -159,8 +214,7 @@ class LoRASpecialNetwork(LoRANetwork):
# assertion
names = set()
for lora in self.text_encoder_loras + self.unet_loras:
# doesnt work on new diffusers. TODO make sure we are not missing something
# assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)
def save_weights(self, file, dtype, metadata):

View File

@@ -0,0 +1,63 @@
from typing import Union
import sys
import os
from toolkit.paths import REPOS_ROOT
sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from leco import train_util
import torch
class PromptEmbeds:
text_embeds: torch.FloatTensor
pooled_embeds: Union[torch.FloatTensor, None]
def __init__(self, args) -> None:
if isinstance(args, list) or isinstance(args, tuple):
# xl
self.text_embeds = args[0]
self.pooled_embeds = args[1]
else:
# sdv1.x, sdv2.x
self.text_embeds = args
self.pooled_embeds = None
class StableDiffusion:
def __init__(
self,
vae,
tokenizer,
text_encoder,
unet,
noise_scheduler,
is_xl=False
):
# text encoder has a list of 2 for xl
self.vae = vae
self.tokenizer = tokenizer
self.text_encoder = text_encoder
self.unet = unet
self.noise_scheduler = noise_scheduler
self.is_xl = is_xl
def encode_prompt(self, prompt, num_images_per_prompt=1) -> PromptEmbeds:
prompt = prompt
# if it is not a list, make it one
if not isinstance(prompt, list):
prompt = [prompt]
if self.is_xl:
return PromptEmbeds(
train_util.encode_prompts_xl(
self.tokenizer,
self.text_encoder,
prompt,
num_images_per_prompt=num_images_per_prompt,
)
)
else:
return PromptEmbeds(
train_util.encode_prompts(
self.tokenizer, self.text_encoder, prompt
)
)

View File

@@ -21,6 +21,8 @@ from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipel
import torch
import re
from toolkit.stable_diffusion_model import PromptEmbeds
SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120
SCHEDULER_TIMESTEPS = 1000
@@ -377,3 +379,19 @@ def apply_noise_offset(noise, noise_offset):
return noise
noise = noise + noise_offset * torch.randn((noise.shape[0], noise.shape[1], 1, 1), device=noise.device)
return noise
def concat_prompt_embeddings(
unconditional: PromptEmbeds,
conditional: PromptEmbeds,
n_imgs: int,
):
text_embeds = torch.cat(
[unconditional.text_embeds, conditional.text_embeds]
).repeat_interleave(n_imgs, dim=0)
pooled_embeds = None
if unconditional.pooled_embeds is not None and conditional.pooled_embeds is not None:
pooled_embeds = torch.cat(
[unconditional.pooled_embeds, conditional.pooled_embeds]
).repeat_interleave(n_imgs, dim=0)
return PromptEmbeds([text_embeds, pooled_embeds])