Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
SDXL should be working, but I broke something where it is not converging.
@@ -17,6 +17,7 @@ class BaseJob:
            raise ValueError('config is required')

        self.config = config['config']
        self.raw_config = config
        self.job = config['job']
        self.name = self.get_conf('name', required=True)
        if 'meta' in config:

@@ -1,3 +1,4 @@
import json
import os

from jobs import BaseJob
@@ -6,7 +7,7 @@ from collections import OrderedDict
from typing import List
from jobs.process import BaseExtractProcess, TrainFineTuneProcess
from datetime import datetime

import yaml
from toolkit.paths import REPOS_ROOT

import sys
@@ -16,6 +17,7 @@ sys.path.append(REPOS_ROOT)
process_dict = {
    'vae': 'TrainVAEProcess',
    'slider': 'TrainSliderProcess',
    'lora_hack': 'TrainLoRAHack',
}


@@ -37,6 +39,13 @@ class TrainJob(BaseJob):
        # loads the processes from the config
        self.load_processes(process_dict)

    def save_training_config(self):
        timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
        os.makedirs(self.training_folder, exist_ok=True)
        save_dif = os.path.join(self.training_folder, f'run_config_{timestamp}.yaml')
        with open(save_dif, 'w') as f:
            yaml.dump(self.raw_config, f)

    def run(self):
        super().run()
        print("")
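The new 'lora_hack' entry in process_dict is resolved by load_processes into the TrainLoRAHack class added later in this commit. For orientation, a minimal sketch of a job config that would select it; only the job, config, name, and process-type keys are grounded in the code above, everything else is an assumption.

# Hypothetical job config; 'lora_hack' maps to 'TrainLoRAHack' via process_dict.
raw_config = {
    'job': 'train',
    'config': {
        'name': 'lora_hack_test',    # required by BaseJob.get_conf('name', required=True)
        'process': [
            {'type': 'lora_hack'},   # assumed shape of a process entry
        ],
    },
}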
@@ -2,7 +2,6 @@ import time
from collections import OrderedDict
import os

from leco.train_util import predict_noise
from toolkit.kohya_model_util import load_vae
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.optimizer import get_optimizer
@@ -12,7 +11,7 @@ import sys
sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))

from diffusers import StableDiffusionPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline

from jobs.process import BaseTrainProcess
from toolkit.metadata import get_meta_for_safetensors
@@ -24,6 +23,7 @@ from tqdm import tqdm

from leco import train_util, model_util
from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig
from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds


def flush():
@@ -35,15 +35,6 @@ UNET_IN_CHANNELS = 4  # in_channels for Stable Diffusion is fixed at 4; same for XL
VAE_SCALE_FACTOR = 8  # 2 ** (len(vae.config.block_out_channels) - 1) = 8


class StableDiffusion:
    def __init__(self, vae, tokenizer, text_encoder, unet, noise_scheduler):
        self.vae = vae
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.unet = unet
        self.noise_scheduler = noise_scheduler


class BaseSDTrainProcess(BaseTrainProcess):
    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
@@ -80,26 +71,44 @@ class BaseSDTrainProcess(BaseTrainProcess):
        original_device_dict = {
            'vae': self.sd.vae.device,
            'unet': self.sd.unet.device,
            'text_encoder': self.sd.text_encoder.device,
            # 'tokenizer': self.sd.tokenizer.device,
        }

        # handle sdxl text encoder
        if isinstance(self.sd.text_encoder, list):
            for encoder, i in zip(self.sd.text_encoder, range(len(self.sd.text_encoder))):
                original_device_dict[f'text_encoder_{i}'] = encoder.device
                encoder.to(self.device_torch)
        else:
            original_device_dict['text_encoder'] = self.sd.text_encoder.device
            self.sd.text_encoder.to(self.device_torch)

        self.sd.vae.to(self.device_torch)
        self.sd.unet.to(self.device_torch)
        self.sd.text_encoder.to(self.device_torch)
        # self.sd.text_encoder.to(self.device_torch)
        # self.sd.tokenizer.to(self.device_torch)
        # TODO add clip skip

        pipeline = StableDiffusionPipeline(
            vae=self.sd.vae,
            unet=self.sd.unet,
            text_encoder=self.sd.text_encoder,
            tokenizer=self.sd.tokenizer,
            scheduler=self.sd.noise_scheduler,
            safety_checker=None,
            feature_extractor=None,
            requires_safety_checker=False,
        )
        if self.sd.is_xl:
            pipeline = StableDiffusionXLPipeline(
                vae=self.sd.vae,
                unet=self.sd.unet,
                text_encoder=self.sd.text_encoder[0],
                text_encoder_2=self.sd.text_encoder[1],
                tokenizer=self.sd.tokenizer[0],
                tokenizer_2=self.sd.tokenizer[1],
                scheduler=self.sd.noise_scheduler,
            )
        else:
            pipeline = StableDiffusionPipeline(
                vae=self.sd.vae,
                unet=self.sd.unet,
                text_encoder=self.sd.text_encoder,
                tokenizer=self.sd.tokenizer,
                scheduler=self.sd.noise_scheduler,
                safety_checker=None,
                feature_extractor=None,
                requires_safety_checker=False,
            )
        # disable progress bar
        pipeline.set_progress_bar_config(disable=True)
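The pattern above, where the text encoder is either a single module or a list of two for SDXL, recurs throughout this commit. A minimal standalone sketch of the same device bookkeeping, with hypothetical helper names; it assumes transformers-style modules that expose a .device property, as the diff itself does.

def move_text_encoder(text_encoder, device):
    """Move one encoder or a list of encoders (SDXL) to `device`,
    returning the original device(s) so they can be restored later."""
    if isinstance(text_encoder, list):
        original = [enc.device for enc in text_encoder]  # .device assumes transformers models
        for enc in text_encoder:
            enc.to(device)
        return original
    original = text_encoder.device
    text_encoder.to(device)
    return original

def restore_text_encoder(text_encoder, original):
    """Undo move_text_encoder using the recorded device(s)."""
    if isinstance(text_encoder, list):
        for enc, dev in zip(text_encoder, original):
            enc.to(dev)
    else:
        text_encoder.to(original)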
@@ -118,7 +127,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                'multiplier': self.network.multiplier,
            })

        for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}", leave=False):
        for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}",
                      leave=False):
            raw_prompt = self.sample_config.prompts[i]

            neg = self.sample_config.neg
@@ -180,7 +190,11 @@ class BaseSDTrainProcess(BaseTrainProcess):

        self.sd.vae.to(original_device_dict['vae'])
        self.sd.unet.to(original_device_dict['unet'])
        self.sd.text_encoder.to(original_device_dict['text_encoder'])
        if isinstance(self.sd.text_encoder, list):
            for encoder, i in zip(self.sd.text_encoder, range(len(self.sd.text_encoder))):
                encoder.to(original_device_dict[f'text_encoder_{i}'])
        else:
            self.sd.text_encoder.to(original_device_dict['text_encoder'])
        if self.network is not None:
            self.network.train()
            self.network.multiplier = start_multiplier
@@ -267,23 +281,90 @@ class BaseSDTrainProcess(BaseTrainProcess):
        # return loss
        return 0.0

    # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
    def diffuse_some_steps(
    def get_time_ids_from_latents(self, latents):
        bs, ch, h, w = list(latents.shape)

        height = h * VAE_SCALE_FACTOR
        width = w * VAE_SCALE_FACTOR

        dtype = get_torch_dtype(self.train_config.dtype)

        if self.sd.is_xl:
            prompt_ids = train_util.get_add_time_ids(
                height,
                width,
                dynamic_crops=False,  # look into this
                dtype=dtype,
            ).to(self.device_torch, dtype=dtype)
            return train_util.concat_embeddings(
                prompt_ids, prompt_ids, bs
            )
        else:
            return None

    def predict_noise(
            self,
            latents: torch.FloatTensor,
            text_embeddings: torch.FloatTensor,
            total_timesteps: int = 1000,
            start_timesteps=0,
            text_embeddings: PromptEmbeds,
            timestep: int,
            guidance_scale=7.5,
            guidance_rescale=0.7,
            add_time_ids=None,
            **kwargs,
    ):

        for timestep in tqdm(self.sd.noise_scheduler.timesteps[start_timesteps:total_timesteps], leave=False):
            noise_pred = train_util.predict_noise(
                self.sd.unet, self.sd.noise_scheduler, timestep, latents, text_embeddings, **kwargs
        if self.sd.is_xl:
            if add_time_ids is None:
                add_time_ids = self.get_time_ids_from_latents(latents)
            # todo LECOs code looks like it is omitting noise_pred
            noise_pred = train_util.predict_noise_xl(
                self.sd.unet,
                self.sd.noise_scheduler,
                timestep,
                latents,
                text_embeddings.text_embeds,
                text_embeddings.pooled_embeds,
                add_time_ids,
                guidance_scale=guidance_scale,
                guidance_rescale=guidance_rescale
            )

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.sd.noise_scheduler.step(noise_pred, timestep, latents).prev_sample
        else:
            noise_pred = train_util.predict_noise(
                self.sd.unet,
                self.sd.noise_scheduler,
                timestep,
                latents,
                text_embeddings.text_embeds,
                guidance_scale=guidance_scale
            )

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.sd.noise_scheduler.step(noise_pred, timestep, latents).prev_sample
        return latents

    # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
    def diffuse_some_steps(
            self,
            latents: torch.FloatTensor,
            text_embeddings: PromptEmbeds,
            total_timesteps: int = 1000,
            start_timesteps=0,
            guidance_scale=1,
            add_time_ids=None,
            **kwargs,
    ):

        for timestep in tqdm(self.sd.noise_scheduler.timesteps[start_timesteps:total_timesteps], leave=False):
            latents = self.predict_noise(
                latents,
                text_embeddings,
                timestep,
                guidance_scale=guidance_scale,
                add_time_ids=add_time_ids,
                **kwargs,
            )

        # return latents_steps
        return latents
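Independent of the leco train_util helpers used above, a single guided denoising step always has the same shape: run the UNet on an [unconditional, conditional] batch, combine the two predictions with the guidance scale, then step the scheduler. A generic sketch using only the diffusers/torch APIs; the function and argument names here are illustrative, not code from this commit.

import torch

def guided_denoise_step(unet, scheduler, timestep, latents, embeds, pooled, time_ids,
                        guidance_scale=7.5, is_xl=False):
    # Sketch only: duplicate latents for the unconditional/conditional passes (CFG),
    # assuming `embeds`, `pooled`, and `time_ids` are already concatenated as
    # [unconditional, conditional] along the batch dimension.
    latent_input = torch.cat([latents] * 2)
    latent_input = scheduler.scale_model_input(latent_input, timestep)

    added_cond = {"text_embeds": pooled, "time_ids": time_ids} if is_xl else None
    noise_pred = unet(
        latent_input,
        timestep,
        encoder_hidden_states=embeds,
        added_cond_kwargs=added_cond,
    ).sample

    uncond, cond = noise_pred.chunk(2)
    noise_pred = uncond + guidance_scale * (cond - uncond)

    # x_t -> x_{t-1}
    return scheduler.step(noise_pred, timestep, latents).prev_sample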
@@ -296,20 +377,35 @@ class BaseSDTrainProcess(BaseTrainProcess):

        dtype = get_torch_dtype(self.train_config.dtype)

        tokenizer, text_encoder, unet, noise_scheduler = model_util.load_models(
            self.model_config.name_or_path,
            scheduler_name=self.train_config.noise_scheduler,
            v2=self.model_config.is_v2,
            v_pred=self.model_config.is_v_pred,
        )
        if self.model_config.is_xl:
            tokenizer, text_encoders, unet, noise_scheduler = model_util.load_models_xl(
                self.model_config.name_or_path,
                scheduler_name=self.train_config.noise_scheduler,
                weight_dtype=dtype,
            )

            for text_encoder in text_encoders:
                text_encoder.to(self.device_torch, dtype=dtype)
                text_encoder.requires_grad_(False)
                text_encoder.eval()

            text_encoder = text_encoders
        else:
            tokenizer, text_encoder, unet, noise_scheduler = model_util.load_models(
                self.model_config.name_or_path,
                scheduler_name=self.train_config.noise_scheduler,
                v2=self.model_config.is_v2,
                v_pred=self.model_config.is_v_pred,
            )

            text_encoder.to(self.device_torch, dtype=dtype)
            text_encoder.eval()

        # just for now or if we want to load a custom one
        # put on cpu for now, we only need it when sampling
        vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype)
        vae.eval()
        self.sd = StableDiffusion(vae, tokenizer, text_encoder, unet, noise_scheduler)

        text_encoder.to(self.device_torch, dtype=dtype)
        text_encoder.eval()
        self.sd = StableDiffusion(vae, tokenizer, text_encoder, unet, noise_scheduler, is_xl=self.model_config.is_xl)

        unet.to(self.device_torch, dtype=dtype)
        if self.train_config.xformers:
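A compact way to express the freezing logic above is to normalize the loaded text encoder(s) into a list right after loading, so downstream code only ever handles one shape. This is a sketch of the idea, not code from the commit; the helper name is hypothetical.

def freeze_text_encoders(text_encoder, device, dtype):
    """Accept a single encoder or a list (SDXL); freeze and move each one.
    Returns the encoders as a list so later code does not need isinstance checks."""
    encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
    for enc in encoders:
        enc.to(device, dtype=dtype)
        enc.requires_grad_(False)
        enc.eval()
    return encoders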
@@ -323,7 +419,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
            unet=unet,
            lora_dim=self.network_config.rank,
            multiplier=1.0,
            alpha=self.network_config.alpha
            alpha=self.network_config.alpha,
            train_unet=self.train_config.train_unet,
            train_text_encoder=self.train_config.train_text_encoder,
        )

        self.network.force_to(self.device_torch, dtype=dtype)
@@ -376,8 +474,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
        self.hook_before_train_loop()

        # sample first
        self.print("Generating baseline samples before training")
        self.sample(0)
        if self.train_config.skip_first_sample:
            self.print("Skipping first sample due to config setting")
        else:
            self.print("Generating baseline samples before training")
            self.sample(0)

        self.progress_bar = tqdm(
            total=self.train_config.steps,
jobs/process/TrainLoRAHack.py (new file, 76 lines)
@@ -0,0 +1,76 @@
# ref:
# - https://github.com/p1atdev/LECO/blob/main/train_lora.py
import time
from collections import OrderedDict
import os

from toolkit.config_modules import SliderConfig
from toolkit.paths import REPOS_ROOT
import sys

sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import gc

import torch
from leco import train_util, model_util
from leco.prompt_util import PromptEmbedsCache
from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion


def flush():
    torch.cuda.empty_cache()
    gc.collect()


class LoRAHack:
    def __init__(self, **kwargs):
        self.type = kwargs.get('type', 'suppression')


class TrainLoRAHack(BaseSDTrainProcess):
    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        self.hack_config = LoRAHack(**self.get_conf('hack', {}))

    def hook_before_train_loop(self):
        # we don't need text encoder so move it to cpu
        self.sd.text_encoder.to("cpu")
        flush()
        # end hook_before_train_loop

        if self.hack_config.type == 'suppression':
            # set all params to self.current_suppression
            params = self.network.parameters()
            for param in params:
                # get random noise for each param
                noise = torch.randn_like(param) - 0.5
                # apply noise to param
                param.data = noise * 0.001


    def supress_loop(self):
        dtype = get_torch_dtype(self.train_config.dtype)


        loss_dict = OrderedDict(
            {'sup': 0.0}
        )
        # increase noise
        for param in self.network.parameters():
            # get random noise for each param
            noise = torch.randn_like(param) - 0.5
            # apply noise to param
            param.data = param.data + noise * 0.001



        return loss_dict

    def hook_train_loop(self):
        if self.hack_config.type == 'suppression':
            return self.supress_loop()
        else:
            raise NotImplementedError(f'unknown hack type: {self.hack_config.type}')
        # end hook_train_loop
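The "suppression" hack above amounts to repeatedly nudging every LoRA weight with small, shifted random noise rather than computing a gradient. A standalone restatement of that idea, for illustration only; the shift and scale constants mirror the file above, and the helper name is hypothetical.

import torch

@torch.no_grad()
def perturb_parameters(module, scale=0.001, shift=-0.5, reset=False):
    """Add small shifted Gaussian noise to every parameter of `module`.
    With reset=True the parameters are replaced by the noise instead,
    matching the initialisation done in hook_before_train_loop above."""
    for param in module.parameters():
        noise = (torch.randn_like(param) + shift) * scale
        param.data = noise if reset else param.data + noise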
@@ -3,19 +3,22 @@
import time
from collections import OrderedDict
import os
from typing import Optional

from toolkit.config_modules import SliderConfig
from toolkit.paths import REPOS_ROOT
import sys

from toolkit.stable_diffusion_model import PromptEmbeds

sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import gc
from toolkit import train_tools

import torch
from leco import train_util, model_util
from leco.prompt_util import PromptEmbedsCache
from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion


@@ -29,7 +32,6 @@ def flush():
    gc.collect()




class EncodedPromptPair:
    def __init__(
            self,
@@ -54,6 +56,19 @@ class EncodedPromptPair:
        self.weight = weight


class PromptEmbedsCache:  # because we want to reuse these
    prompts: dict[str, PromptEmbeds] = {}

    def __setitem__(self, __name: str, __value: PromptEmbeds) -> None:
        self.prompts[__name] = __value

    def __getitem__(self, __name: str) -> Optional[PromptEmbeds]:
        if __name in self.prompts:
            return self.prompts[__name]
        else:
            return None


class EncodedAnchor:
    def __init__(
            self,
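PromptEmbedsCache behaves like a dictionary that returns None on a miss, which is what makes the `if cache[prompt] is None: cache[prompt] = ...` pattern in the hunks below work. A hypothetical usage sketch; `sd` stands in for the StableDiffusion wrapper added in toolkit/stable_diffusion_model.py later in this commit.

cache = PromptEmbedsCache()

for prompt in ["a photo of a dog", "a photo of a dog", ""]:
    if cache[prompt] is None:      # a miss returns None instead of raising
        cache[prompt] = sd.encode_prompt(prompt)
    embeds = cache[prompt]         # the repeated prompt hits the cache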
@@ -89,19 +104,17 @@ class TrainSliderProcess(BaseSDTrainProcess):
        with torch.no_grad():
            neutral = ""
            for target in self.slider_config.targets:
                # build the cache
                for prompt in [
                    target.target_class,
                    target.positive,
                    target.negative,
                    neutral  # empty neutral
                ]:
                    if cache[prompt] is None:
                        cache[prompt] = self.sd.encode_prompt(prompt)
                for resolution in self.slider_config.resolutions:
                    width, height = resolution
                    # build the cache
                    for prompt in [
                        target.target_class,
                        target.positive,
                        target.negative,
                        neutral  # empty neutral
                    ]:
                        if cache[prompt] == None:
                            cache[prompt] = train_util.encode_prompts(
                                self.sd.tokenizer, self.sd.text_encoder, [prompt]
                            )
                only_erase = len(target.positive.strip()) == 0
                only_enhance = len(target.negative.strip()) == 0

@@ -184,9 +197,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
                    anchor.neg_prompt  # empty neutral
                ]:
                    if cache[prompt] == None:
                        cache[prompt] = train_util.encode_prompts(
                            self.sd.tokenizer, self.sd.text_encoder, [prompt]
                        )
                        cache[prompt] = self.sd.encode_prompt(prompt)

                anchor_pairs += [
                    EncodedAnchor(
@@ -198,7 +209,12 @@ class TrainSliderProcess(BaseSDTrainProcess):

        # move to cpu to save vram
        # We don't need text encoder anymore, but keep it on cpu for sampling
        self.sd.text_encoder.to("cpu")
        # if text encoder is list
        if isinstance(self.sd.text_encoder, list):
            for encoder in self.sd.text_encoder:
                encoder.to("cpu")
        else:
            self.sd.text_encoder.to("cpu")
        self.prompt_cache = cache
        self.prompt_pairs = prompt_pairs
        self.anchor_pairs = anchor_pairs
@@ -220,6 +236,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
        negative = prompt_pair.negative
        positive = prompt_pair.positive
        weight = prompt_pair.weight
        multiplier = prompt_pair.multiplier

        unet = self.sd.unet
        noise_scheduler = self.sd.noise_scheduler
@@ -227,8 +244,20 @@ class TrainSliderProcess(BaseSDTrainProcess):
        lr_scheduler = self.lr_scheduler
        loss_function = torch.nn.MSELoss()

        def get_noise_pred(p, n):
            return self.predict_noise(
                latents=denoised_latents,
                text_embeddings=train_tools.concat_prompt_embeddings(
                    p,  # unconditional
                    n,  # positive
                    self.train_config.batch_size,
                ),
                timestep=current_timestep,
                guidance_scale=1,
            )

        # set network multiplier
        self.network.multiplier = prompt_pair.multiplier
        self.network.multiplier = multiplier

        with torch.no_grad():
            self.sd.noise_scheduler.set_timesteps(
@@ -254,9 +283,10 @@ class TrainSliderProcess(BaseSDTrainProcess):

            with self.network:
                assert self.network.is_active
                self.network.multiplier = multiplier
                denoised_latents = self.diffuse_some_steps(
                    latents,  # pass simple noise latents
                    train_util.concat_embeddings(
                    train_tools.concat_prompt_embeddings(
                        positive,  # unconditional
                        target_class,  # target
                        self.train_config.batch_size,
@@ -272,43 +302,11 @@ class TrainSliderProcess(BaseSDTrainProcess):
                int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
            ]

            # with network: 0 weight LoRA is enabled outside "with network:"
            positive_latents = train_util.predict_noise(  # positive_latents
                unet,
                noise_scheduler,
                current_timestep,
                denoised_latents,
                train_util.concat_embeddings(
                    positive,  # unconditional
                    negative,  # positive
                    self.train_config.batch_size,
                ),
                guidance_scale=1,
            ).to("cpu", dtype=torch.float32)
            neutral_latents = train_util.predict_noise(
                unet,
                noise_scheduler,
                current_timestep,
                denoised_latents,
                train_util.concat_embeddings(
                    positive,  # unconditional
                    neutral,  # neutral
                    self.train_config.batch_size,
                ),
                guidance_scale=1,
            ).to("cpu", dtype=torch.float32)
            unconditional_latents = train_util.predict_noise(
                unet,
                noise_scheduler,
                current_timestep,
                denoised_latents,
                train_util.concat_embeddings(
                    positive,  # unconditional
                    positive,  # unconditional
                    self.train_config.batch_size,
                ),
                guidance_scale=1,
            ).to("cpu", dtype=torch.float32)
            positive_latents = get_noise_pred(positive, negative)

            neutral_latents = get_noise_pred(positive, neutral)

            unconditional_latents = get_noise_pred(positive, positive)

        anchor_loss = None
        if len(self.anchor_pairs) > 0:
@@ -317,51 +315,19 @@ class TrainSliderProcess(BaseSDTrainProcess):
                torch.randint(0, len(self.anchor_pairs), (1,)).item()
            ]
            with torch.no_grad():
                anchor_target_noise = train_util.predict_noise(
                    unet,
                    noise_scheduler,
                    current_timestep,
                    denoised_latents,
                    train_util.concat_embeddings(
                        anchor.prompt,
                        anchor.neg_prompt,
                        self.train_config.batch_size,
                    ),
                    guidance_scale=1,
                ).to("cpu", dtype=torch.float32)
                anchor_target_noise = get_noise_pred(anchor.prompt, anchor.neg_prompt)
            with self.network:
                # anchor whatever weight prompt pair is using
                pos_nem_mult = 1.0 if prompt_pair.multiplier > 0 else -1.0
                self.network.multiplier = anchor.multiplier * pos_nem_mult
                anchor_pred_noise = train_util.predict_noise(
                    unet,
                    noise_scheduler,
                    current_timestep,
                    denoised_latents,
                    train_util.concat_embeddings(
                        anchor.prompt,
                        anchor.neg_prompt,
                        self.train_config.batch_size,
                    ),
                    guidance_scale=1,
                ).to("cpu", dtype=torch.float32)

                anchor_pred_noise = get_noise_pred(anchor.prompt, anchor.neg_prompt)

                self.network.multiplier = prompt_pair.multiplier

        with self.network:
            self.network.multiplier = prompt_pair.multiplier
            target_latents = train_util.predict_noise(
                unet,
                noise_scheduler,
                current_timestep,
                denoised_latents,
                train_util.concat_embeddings(
                    positive,  # unconditional
                    target_class,  # target
                    self.train_config.batch_size,
                ),
                guidance_scale=1,
            ).to("cpu", dtype=torch.float32)
            target_latents = get_noise_pred(positive, target_class)

            # if self.logging_config.verbose:
            # self.print("target_latents:", target_latents[0, 0, :5, :5])
@@ -6,3 +6,4 @@ from .BaseTrainProcess import BaseTrainProcess
from .TrainVAEProcess import TrainVAEProcess
from .BaseMergeProcess import BaseMergeProcess
from .TrainSliderProcess import TrainSliderProcess
from .TrainLoRAHack import TrainLoRAHack

@@ -9,4 +9,5 @@ accelerator
pyyaml
oyaml
tensorboard
kornia
kornia
invisible-watermark
@@ -50,12 +50,14 @@ class TrainConfig:
        self.train_text_encoder = kwargs.get('train_text_encoder', True)
        self.noise_offset = kwargs.get('noise_offset', 0.0)
        self.optimizer_params = kwargs.get('optimizer_params', {})
        self.skip_first_sample = kwargs.get('skip_first_sample', False)


class ModelConfig:
    def __init__(self, **kwargs):
        self.name_or_path: str = kwargs.get('name_or_path', None)
        self.is_v2: bool = kwargs.get('is_v2', False)
        self.is_xl: bool = kwargs.get('is_xl', False)
        self.is_v_pred: bool = kwargs.get('is_v_pred', False)

        if self.name_or_path is None:
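A minimal sketch of how the new flags surface when instantiating these config objects; the keyword names match the kwargs.get(...) calls above, while the values are illustrative and not taken from the commit.

model_config = ModelConfig(
    name_or_path='stabilityai/stable-diffusion-xl-base-1.0',  # illustrative model id
    is_xl=True,
)
train_config = TrainConfig(
    train_unet=True,
    train_text_encoder=False,   # often frozen for SDXL to save memory
    skip_first_sample=True,     # skip the baseline sample at step 0
)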
@@ -1,8 +1,10 @@
import os
import sys
from typing import List
from typing import List, Optional, Dict, Type, Union

import torch
from transformers import CLIPTextModel

from .paths import SD_SCRIPTS_ROOT

sys.path.append(SD_SCRIPTS_ROOT)
@@ -14,26 +16,40 @@ class LoRASpecialNetwork(LoRANetwork):
    _multiplier: float = 1.0
    is_active: bool = False

    NUM_OF_BLOCKS = 12  # number of up/down blocks in the full model

    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
    LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
    LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"

    def __init__(
            self,
            text_encoder,
            text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
            unet,
            multiplier=1.0,
            lora_dim=4,
            alpha=1,
            dropout=None,
            rank_dropout=None,
            module_dropout=None,
            conv_lora_dim=None,
            conv_alpha=None,
            block_dims=None,
            block_alphas=None,
            conv_block_dims=None,
            conv_block_alphas=None,
            modules_dim=None,
            modules_alpha=None,
            module_class=LoRAModule,
            varbose=False,
            multiplier: float = 1.0,
            lora_dim: int = 4,
            alpha: float = 1,
            dropout: Optional[float] = None,
            rank_dropout: Optional[float] = None,
            module_dropout: Optional[float] = None,
            conv_lora_dim: Optional[int] = None,
            conv_alpha: Optional[float] = None,
            block_dims: Optional[List[int]] = None,
            block_alphas: Optional[List[float]] = None,
            conv_block_dims: Optional[List[int]] = None,
            conv_block_alphas: Optional[List[float]] = None,
            modules_dim: Optional[Dict[str, int]] = None,
            modules_alpha: Optional[Dict[str, int]] = None,
            module_class: Type[object] = LoRAModule,
            varbose: Optional[bool] = False,
            train_text_encoder: Optional[bool] = True,
            train_unet: Optional[bool] = True,
    ) -> None:
        """
        LoRA network: there are a lot of arguments, but the patterns are as follows
@@ -75,8 +91,21 @@ class LoRASpecialNetwork(LoRANetwork):
                f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")

        # create module instances
        def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
            prefix = LoRANetwork.LORA_PREFIX_UNET if is_unet else LoRANetwork.LORA_PREFIX_TEXT_ENCODER
        def create_modules(
                is_unet: bool,
                text_encoder_idx: Optional[int],  # None, 1, 2
                root_module: torch.nn.Module,
                target_replace_modules: List[torch.nn.Module],
        ) -> List[LoRAModule]:
            prefix = (
                self.LORA_PREFIX_UNET
                if is_unet
                else (
                    self.LORA_PREFIX_TEXT_ENCODER
                    if text_encoder_idx is None
                    else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
                )
            )
            loras = []
            skipped = []
            for name, module in root_module.named_modules():
@@ -92,11 +121,14 @@ class LoRASpecialNetwork(LoRANetwork):

                    dim = None
                    alpha = None

                    if modules_dim is not None:
                        # modules are specified explicitly
                        if lora_name in modules_dim:
                            dim = modules_dim[lora_name]
                            alpha = modules_alpha[lora_name]
                    elif is_unet and block_dims is not None:
                        # block_dims specified for the U-Net
                        block_idx = get_block_index(lora_name)
                        if is_linear or is_conv2d_1x1:
                            dim = block_dims[block_idx]
@@ -105,6 +137,7 @@ class LoRASpecialNetwork(LoRANetwork):
                            dim = conv_block_dims[block_idx]
                            alpha = conv_block_alphas[block_idx]
                    else:
                        # default: target everything
                        if is_linear or is_conv2d_1x1:
                            dim = self.lora_dim
                            alpha = self.alpha
@@ -113,6 +146,7 @@ class LoRASpecialNetwork(LoRANetwork):
                            alpha = self.conv_alpha

                    if dim is None or dim == 0:
                        # report which modules were skipped
                        if is_linear or is_conv2d_1x1 or (
                                self.conv_lora_dim is not None or conv_block_dims is not None):
                            skipped.append(lora_name)
@@ -131,8 +165,25 @@ class LoRASpecialNetwork(LoRANetwork):
                    loras.append(lora)
            return loras, skipped

        self.text_encoder_loras, skipped_te = create_modules(False, text_encoder,
        text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]

        # create LoRA for text encoder
        # creating every module on every call is wasteful; worth revisiting
        self.text_encoder_loras = []
        skipped_te = []
        if train_text_encoder:
            for i, text_encoder in enumerate(text_encoders):
                if len(text_encoders) > 1:
                    index = i + 1
                    print(f"create LoRA for Text Encoder {index}:")
                else:
                    index = None
                    print(f"create LoRA for Text Encoder:")

                text_encoder_loras, skipped = create_modules(False, index, text_encoder,
                                                             LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
                self.text_encoder_loras.extend(text_encoder_loras)
                skipped_te += skipped
        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
@@ -140,7 +191,11 @@ class LoRASpecialNetwork(LoRANetwork):
        if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
            target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3

        self.unet_loras, skipped_un = create_modules(True, unet, target_modules)
        if train_unet:
            self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
        else:
            self.unet_loras = []
            skipped_un = []
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

        skipped = skipped_te + skipped_un
@@ -159,8 +214,7 @@ class LoRASpecialNetwork(LoRANetwork):
        # assertion
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            # doesn't work on new diffusers. TODO make sure we are not missing something
            # assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)

    def save_weights(self, file, dtype, metadata):
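A sketch of how the extended constructor might be called for an SDXL model, with the two text encoders passed as a list; the argument values are illustrative, and te1, te2, and unet are assumed to be already-loaded modules.

network = LoRASpecialNetwork(
    text_encoder=[te1, te2],      # SDXL: both CLIP text encoders
    unet=unet,
    lora_dim=8,
    alpha=4,
    multiplier=1.0,
    train_unet=True,
    train_text_encoder=False,     # skip text-encoder LoRA modules entirely
)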
toolkit/stable_diffusion_model.py (new file, 63 lines)
@@ -0,0 +1,63 @@
from typing import Union
import sys
import os
from toolkit.paths import REPOS_ROOT
sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from leco import train_util
import torch


class PromptEmbeds:
    text_embeds: torch.FloatTensor
    pooled_embeds: Union[torch.FloatTensor, None]

    def __init__(self, args) -> None:
        if isinstance(args, list) or isinstance(args, tuple):
            # xl
            self.text_embeds = args[0]
            self.pooled_embeds = args[1]
        else:
            # sdv1.x, sdv2.x
            self.text_embeds = args
            self.pooled_embeds = None


class StableDiffusion:
    def __init__(
            self,
            vae,
            tokenizer,
            text_encoder,
            unet,
            noise_scheduler,
            is_xl=False
    ):
        # text encoder has a list of 2 for xl
        self.vae = vae
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.unet = unet
        self.noise_scheduler = noise_scheduler
        self.is_xl = is_xl

    def encode_prompt(self, prompt, num_images_per_prompt=1) -> PromptEmbeds:
        prompt = prompt
        # if it is not a list, make it one
        if not isinstance(prompt, list):
            prompt = [prompt]
        if self.is_xl:
            return PromptEmbeds(
                train_util.encode_prompts_xl(
                    self.tokenizer,
                    self.text_encoder,
                    prompt,
                    num_images_per_prompt=num_images_per_prompt,
                )
            )
        else:
            return PromptEmbeds(
                train_util.encode_prompts(
                    self.tokenizer, self.text_encoder, prompt
                )
            )
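The point of PromptEmbeds is to let SD 1.x/2.x and SDXL code paths share one return type: XL callers get both text and pooled embeddings, others get pooled_embeds set to None. A brief hypothetical usage sketch; `sd` is a StableDiffusion wrapper as defined above.

embeds = sd.encode_prompt("a photo of a cat")

unet_inputs = embeds.text_embeds       # always present
if sd.is_xl:
    pooled = embeds.pooled_embeds      # SDXL additionally needs the pooled vector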
@@ -21,6 +21,8 @@ from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipel
import torch
import re

from toolkit.stable_diffusion_model import PromptEmbeds

SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120
SCHEDULER_TIMESTEPS = 1000
@@ -377,3 +379,19 @@ def apply_noise_offset(noise, noise_offset):
        return noise
    noise = noise + noise_offset * torch.randn((noise.shape[0], noise.shape[1], 1, 1), device=noise.device)
    return noise


def concat_prompt_embeddings(
        unconditional: PromptEmbeds,
        conditional: PromptEmbeds,
        n_imgs: int,
):
    text_embeds = torch.cat(
        [unconditional.text_embeds, conditional.text_embeds]
    ).repeat_interleave(n_imgs, dim=0)
    pooled_embeds = None
    if unconditional.pooled_embeds is not None and conditional.pooled_embeds is not None:
        pooled_embeds = torch.cat(
            [unconditional.pooled_embeds, conditional.pooled_embeds]
        ).repeat_interleave(n_imgs, dim=0)
    return PromptEmbeds([text_embeds, pooled_embeds])
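concat_prompt_embeddings stacks the unconditional and conditional PromptEmbeds along the batch dimension and repeats them per image, which is the layout the classifier-free-guidance noise predictions above expect. A small hypothetical check, assuming PromptEmbeds and concat_prompt_embeddings from this commit are importable; the tensor shapes are illustrative SD 1.x values (77 tokens, 768-dim embeddings).

import torch

uncond = PromptEmbeds(torch.zeros(1, 77, 768))
cond = PromptEmbeds(torch.ones(1, 77, 768))

both = concat_prompt_embeddings(uncond, cond, n_imgs=2)
print(both.text_embeds.shape)   # torch.Size([4, 77, 768]): each embedding repeated per image
print(both.pooled_embeds)       # None for non-XL prompts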