mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
SDXL should be working, but I broke something where it is not converging.
This commit is contained in:
@@ -17,6 +17,7 @@ class BaseJob:
|
|||||||
raise ValueError('config is required')
|
raise ValueError('config is required')
|
||||||
|
|
||||||
self.config = config['config']
|
self.config = config['config']
|
||||||
|
self.raw_config = config
|
||||||
self.job = config['job']
|
self.job = config['job']
|
||||||
self.name = self.get_conf('name', required=True)
|
self.name = self.get_conf('name', required=True)
|
||||||
if 'meta' in config:
|
if 'meta' in config:
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from jobs import BaseJob
|
from jobs import BaseJob
|
||||||
@@ -6,7 +7,7 @@ from collections import OrderedDict
|
|||||||
from typing import List
|
from typing import List
|
||||||
from jobs.process import BaseExtractProcess, TrainFineTuneProcess
|
from jobs.process import BaseExtractProcess, TrainFineTuneProcess
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
import yaml
|
||||||
from toolkit.paths import REPOS_ROOT
|
from toolkit.paths import REPOS_ROOT
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
@@ -16,6 +17,7 @@ sys.path.append(REPOS_ROOT)
|
|||||||
process_dict = {
|
process_dict = {
|
||||||
'vae': 'TrainVAEProcess',
|
'vae': 'TrainVAEProcess',
|
||||||
'slider': 'TrainSliderProcess',
|
'slider': 'TrainSliderProcess',
|
||||||
|
'lora_hack': 'TrainLoRAHack',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -37,6 +39,13 @@ class TrainJob(BaseJob):
|
|||||||
# loads the processes from the config
|
# loads the processes from the config
|
||||||
self.load_processes(process_dict)
|
self.load_processes(process_dict)
|
||||||
|
|
||||||
|
def save_training_config(self):
    """Dump the raw job config to a timestamped YAML file in the training folder.

    The file is written to ``self.training_folder`` (created if it does not
    exist yet) as ``run_config_<YYYYmmdd-HHMMSS>.yaml``, so every run leaves
    its own copy of the configuration it was started with.
    """
    run_stamp = datetime.now().strftime('%Y%m%d-%H%M%S')
    config_filename = f'run_config_{run_stamp}.yaml'

    os.makedirs(self.training_folder, exist_ok=True)
    config_path = os.path.join(self.training_folder, config_filename)

    with open(config_path, 'w') as config_file:
        yaml.dump(self.raw_config, config_file)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
super().run()
|
super().run()
|
||||||
print("")
|
print("")
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import time
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from leco.train_util import predict_noise
|
|
||||||
from toolkit.kohya_model_util import load_vae
|
from toolkit.kohya_model_util import load_vae
|
||||||
from toolkit.lora_special import LoRASpecialNetwork
|
from toolkit.lora_special import LoRASpecialNetwork
|
||||||
from toolkit.optimizer import get_optimizer
|
from toolkit.optimizer import get_optimizer
|
||||||
@@ -12,7 +11,7 @@ import sys
|
|||||||
sys.path.append(REPOS_ROOT)
|
sys.path.append(REPOS_ROOT)
|
||||||
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
|
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
|
||||||
|
|
||||||
from diffusers import StableDiffusionPipeline
|
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
||||||
|
|
||||||
from jobs.process import BaseTrainProcess
|
from jobs.process import BaseTrainProcess
|
||||||
from toolkit.metadata import get_meta_for_safetensors
|
from toolkit.metadata import get_meta_for_safetensors
|
||||||
@@ -24,6 +23,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from leco import train_util, model_util
|
from leco import train_util, model_util
|
||||||
from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig
|
from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig
|
||||||
|
from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds
|
||||||
|
|
||||||
|
|
||||||
def flush():
|
def flush():
|
||||||
@@ -35,15 +35,6 @@ UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも
|
|||||||
VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
|
VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusion:
|
|
||||||
def __init__(self, vae, tokenizer, text_encoder, unet, noise_scheduler):
|
|
||||||
self.vae = vae
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.text_encoder = text_encoder
|
|
||||||
self.unet = unet
|
|
||||||
self.noise_scheduler = noise_scheduler
|
|
||||||
|
|
||||||
|
|
||||||
class BaseSDTrainProcess(BaseTrainProcess):
|
class BaseSDTrainProcess(BaseTrainProcess):
|
||||||
def __init__(self, process_id: int, job, config: OrderedDict):
|
def __init__(self, process_id: int, job, config: OrderedDict):
|
||||||
super().__init__(process_id, job, config)
|
super().__init__(process_id, job, config)
|
||||||
@@ -80,26 +71,44 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
original_device_dict = {
|
original_device_dict = {
|
||||||
'vae': self.sd.vae.device,
|
'vae': self.sd.vae.device,
|
||||||
'unet': self.sd.unet.device,
|
'unet': self.sd.unet.device,
|
||||||
'text_encoder': self.sd.text_encoder.device,
|
|
||||||
# 'tokenizer': self.sd.tokenizer.device,
|
# 'tokenizer': self.sd.tokenizer.device,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# handle sdxl text encoder
|
||||||
|
if isinstance(self.sd.text_encoder, list):
|
||||||
|
for encoder, i in zip(self.sd.text_encoder, range(len(self.sd.text_encoder))):
|
||||||
|
original_device_dict[f'text_encoder_{i}'] = encoder.device
|
||||||
|
encoder.to(self.device_torch)
|
||||||
|
else:
|
||||||
|
original_device_dict['text_encoder'] = self.sd.text_encoder.device
|
||||||
|
self.sd.text_encoder.to(self.device_torch)
|
||||||
|
|
||||||
self.sd.vae.to(self.device_torch)
|
self.sd.vae.to(self.device_torch)
|
||||||
self.sd.unet.to(self.device_torch)
|
self.sd.unet.to(self.device_torch)
|
||||||
self.sd.text_encoder.to(self.device_torch)
|
# self.sd.text_encoder.to(self.device_torch)
|
||||||
# self.sd.tokenizer.to(self.device_torch)
|
# self.sd.tokenizer.to(self.device_torch)
|
||||||
# TODO add clip skip
|
# TODO add clip skip
|
||||||
|
if self.sd.is_xl:
|
||||||
pipeline = StableDiffusionPipeline(
|
pipeline = StableDiffusionXLPipeline(
|
||||||
vae=self.sd.vae,
|
vae=self.sd.vae,
|
||||||
unet=self.sd.unet,
|
unet=self.sd.unet,
|
||||||
text_encoder=self.sd.text_encoder,
|
text_encoder=self.sd.text_encoder[0],
|
||||||
tokenizer=self.sd.tokenizer,
|
text_encoder_2=self.sd.text_encoder[1],
|
||||||
scheduler=self.sd.noise_scheduler,
|
tokenizer=self.sd.tokenizer[0],
|
||||||
safety_checker=None,
|
tokenizer_2=self.sd.tokenizer[1],
|
||||||
feature_extractor=None,
|
scheduler=self.sd.noise_scheduler,
|
||||||
requires_safety_checker=False,
|
)
|
||||||
)
|
else:
|
||||||
|
pipeline = StableDiffusionPipeline(
|
||||||
|
vae=self.sd.vae,
|
||||||
|
unet=self.sd.unet,
|
||||||
|
text_encoder=self.sd.text_encoder,
|
||||||
|
tokenizer=self.sd.tokenizer,
|
||||||
|
scheduler=self.sd.noise_scheduler,
|
||||||
|
safety_checker=None,
|
||||||
|
feature_extractor=None,
|
||||||
|
requires_safety_checker=False,
|
||||||
|
)
|
||||||
# disable progress bar
|
# disable progress bar
|
||||||
pipeline.set_progress_bar_config(disable=True)
|
pipeline.set_progress_bar_config(disable=True)
|
||||||
|
|
||||||
@@ -118,7 +127,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
'multiplier': self.network.multiplier,
|
'multiplier': self.network.multiplier,
|
||||||
})
|
})
|
||||||
|
|
||||||
for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}", leave=False):
|
for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}",
|
||||||
|
leave=False):
|
||||||
raw_prompt = self.sample_config.prompts[i]
|
raw_prompt = self.sample_config.prompts[i]
|
||||||
|
|
||||||
neg = self.sample_config.neg
|
neg = self.sample_config.neg
|
||||||
@@ -180,7 +190,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
|
|
||||||
self.sd.vae.to(original_device_dict['vae'])
|
self.sd.vae.to(original_device_dict['vae'])
|
||||||
self.sd.unet.to(original_device_dict['unet'])
|
self.sd.unet.to(original_device_dict['unet'])
|
||||||
self.sd.text_encoder.to(original_device_dict['text_encoder'])
|
if isinstance(self.sd.text_encoder, list):
|
||||||
|
for encoder, i in zip(self.sd.text_encoder, range(len(self.sd.text_encoder))):
|
||||||
|
encoder.to(original_device_dict[f'text_encoder_{i}'])
|
||||||
|
else:
|
||||||
|
self.sd.text_encoder.to(original_device_dict['text_encoder'])
|
||||||
if self.network is not None:
|
if self.network is not None:
|
||||||
self.network.train()
|
self.network.train()
|
||||||
self.network.multiplier = start_multiplier
|
self.network.multiplier = start_multiplier
|
||||||
@@ -267,23 +281,90 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
# return loss
|
# return loss
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
# ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
|
def get_time_ids_from_latents(self, latents):
|
||||||
def diffuse_some_steps(
|
bs, ch, h, w = list(latents.shape)
|
||||||
|
|
||||||
|
height = h * VAE_SCALE_FACTOR
|
||||||
|
width = w * VAE_SCALE_FACTOR
|
||||||
|
|
||||||
|
dtype = get_torch_dtype(self.train_config.dtype)
|
||||||
|
|
||||||
|
if self.sd.is_xl:
|
||||||
|
prompt_ids = train_util.get_add_time_ids(
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
dynamic_crops=False, # look into this
|
||||||
|
dtype=dtype,
|
||||||
|
).to(self.device_torch, dtype=dtype)
|
||||||
|
return train_util.concat_embeddings(
|
||||||
|
prompt_ids, prompt_ids, bs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def predict_noise(
|
||||||
self,
|
self,
|
||||||
latents: torch.FloatTensor,
|
latents: torch.FloatTensor,
|
||||||
text_embeddings: torch.FloatTensor,
|
text_embeddings: PromptEmbeds,
|
||||||
total_timesteps: int = 1000,
|
timestep: int,
|
||||||
start_timesteps=0,
|
guidance_scale=7.5,
|
||||||
|
guidance_rescale=0.7,
|
||||||
|
add_time_ids=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
if self.sd.is_xl:
|
||||||
for timestep in tqdm(self.sd.noise_scheduler.timesteps[start_timesteps:total_timesteps], leave=False):
|
if add_time_ids is None:
|
||||||
noise_pred = train_util.predict_noise(
|
add_time_ids = self.get_time_ids_from_latents(latents)
|
||||||
self.sd.unet, self.sd.noise_scheduler, timestep, latents, text_embeddings, **kwargs
|
# todo LECOs code looks like it is omitting noise_pred
|
||||||
|
noise_pred = train_util.predict_noise_xl(
|
||||||
|
self.sd.unet,
|
||||||
|
self.sd.noise_scheduler,
|
||||||
|
timestep,
|
||||||
|
latents,
|
||||||
|
text_embeddings.text_embeds,
|
||||||
|
text_embeddings.pooled_embeds,
|
||||||
|
add_time_ids,
|
||||||
|
guidance_scale=guidance_scale,
|
||||||
|
guidance_rescale=guidance_rescale
|
||||||
)
|
)
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
latents = self.sd.noise_scheduler.step(noise_pred, timestep, latents).prev_sample
|
latents = self.sd.noise_scheduler.step(noise_pred, timestep, latents).prev_sample
|
||||||
|
else:
|
||||||
|
noise_pred = train_util.predict_noise(
|
||||||
|
self.sd.unet,
|
||||||
|
self.sd.noise_scheduler,
|
||||||
|
timestep,
|
||||||
|
latents,
|
||||||
|
text_embeddings.text_embeds,
|
||||||
|
guidance_scale=guidance_scale
|
||||||
|
)
|
||||||
|
|
||||||
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
|
latents = self.sd.noise_scheduler.step(noise_pred, timestep, latents).prev_sample
|
||||||
|
return latents
|
||||||
|
|
||||||
|
# ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
|
||||||
|
def diffuse_some_steps(
|
||||||
|
self,
|
||||||
|
latents: torch.FloatTensor,
|
||||||
|
text_embeddings: PromptEmbeds,
|
||||||
|
total_timesteps: int = 1000,
|
||||||
|
start_timesteps=0,
|
||||||
|
guidance_scale=1,
|
||||||
|
add_time_ids=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
|
||||||
|
for timestep in tqdm(self.sd.noise_scheduler.timesteps[start_timesteps:total_timesteps], leave=False):
|
||||||
|
latents = self.predict_noise(
|
||||||
|
latents,
|
||||||
|
text_embeddings,
|
||||||
|
timestep,
|
||||||
|
guidance_scale=guidance_scale,
|
||||||
|
add_time_ids=add_time_ids,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
# return latents_steps
|
# return latents_steps
|
||||||
return latents
|
return latents
|
||||||
@@ -296,20 +377,35 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
|
|
||||||
dtype = get_torch_dtype(self.train_config.dtype)
|
dtype = get_torch_dtype(self.train_config.dtype)
|
||||||
|
|
||||||
tokenizer, text_encoder, unet, noise_scheduler = model_util.load_models(
|
if self.model_config.is_xl:
|
||||||
self.model_config.name_or_path,
|
tokenizer, text_encoders, unet, noise_scheduler = model_util.load_models_xl(
|
||||||
scheduler_name=self.train_config.noise_scheduler,
|
self.model_config.name_or_path,
|
||||||
v2=self.model_config.is_v2,
|
scheduler_name=self.train_config.noise_scheduler,
|
||||||
v_pred=self.model_config.is_v_pred,
|
weight_dtype=dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for text_encoder in text_encoders:
|
||||||
|
text_encoder.to(self.device_torch, dtype=dtype)
|
||||||
|
text_encoder.requires_grad_(False)
|
||||||
|
text_encoder.eval()
|
||||||
|
|
||||||
|
text_encoder = text_encoders
|
||||||
|
else:
|
||||||
|
tokenizer, text_encoder, unet, noise_scheduler = model_util.load_models(
|
||||||
|
self.model_config.name_or_path,
|
||||||
|
scheduler_name=self.train_config.noise_scheduler,
|
||||||
|
v2=self.model_config.is_v2,
|
||||||
|
v_pred=self.model_config.is_v_pred,
|
||||||
|
)
|
||||||
|
|
||||||
|
text_encoder.to(self.device_torch, dtype=dtype)
|
||||||
|
text_encoder.eval()
|
||||||
|
|
||||||
# just for now or of we want to load a custom one
|
# just for now or of we want to load a custom one
|
||||||
# put on cpu for now, we only need it when sampling
|
# put on cpu for now, we only need it when sampling
|
||||||
vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype)
|
vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
self.sd = StableDiffusion(vae, tokenizer, text_encoder, unet, noise_scheduler)
|
self.sd = StableDiffusion(vae, tokenizer, text_encoder, unet, noise_scheduler, is_xl=self.model_config.is_xl)
|
||||||
|
|
||||||
text_encoder.to(self.device_torch, dtype=dtype)
|
|
||||||
text_encoder.eval()
|
|
||||||
|
|
||||||
unet.to(self.device_torch, dtype=dtype)
|
unet.to(self.device_torch, dtype=dtype)
|
||||||
if self.train_config.xformers:
|
if self.train_config.xformers:
|
||||||
@@ -323,7 +419,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
unet=unet,
|
unet=unet,
|
||||||
lora_dim=self.network_config.rank,
|
lora_dim=self.network_config.rank,
|
||||||
multiplier=1.0,
|
multiplier=1.0,
|
||||||
alpha=self.network_config.alpha
|
alpha=self.network_config.alpha,
|
||||||
|
train_unet=self.train_config.train_unet,
|
||||||
|
train_text_encoder=self.train_config.train_text_encoder,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.network.force_to(self.device_torch, dtype=dtype)
|
self.network.force_to(self.device_torch, dtype=dtype)
|
||||||
@@ -376,8 +474,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
self.hook_before_train_loop()
|
self.hook_before_train_loop()
|
||||||
|
|
||||||
# sample first
|
# sample first
|
||||||
self.print("Generating baseline samples before training")
|
if self.train_config.skip_first_sample:
|
||||||
self.sample(0)
|
self.print("Skipping first sample due to config setting")
|
||||||
|
else:
|
||||||
|
self.print("Generating baseline samples before training")
|
||||||
|
self.sample(0)
|
||||||
|
|
||||||
self.progress_bar = tqdm(
|
self.progress_bar = tqdm(
|
||||||
total=self.train_config.steps,
|
total=self.train_config.steps,
|
||||||
|
|||||||
76
jobs/process/TrainLoRAHack.py
Normal file
76
jobs/process/TrainLoRAHack.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
# ref:
|
||||||
|
# - https://github.com/p1atdev/LECO/blob/main/train_lora.py
|
||||||
|
import time
|
||||||
|
from collections import OrderedDict
|
||||||
|
import os
|
||||||
|
|
||||||
|
from toolkit.config_modules import SliderConfig
|
||||||
|
from toolkit.paths import REPOS_ROOT
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.append(REPOS_ROOT)
|
||||||
|
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
|
||||||
|
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||||
|
import gc
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from leco import train_util, model_util
|
||||||
|
from leco.prompt_util import PromptEmbedsCache
|
||||||
|
from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion
|
||||||
|
|
||||||
|
|
||||||
|
def flush():
    """Free unreferenced Python objects and release cached CUDA memory.

    Garbage collection runs first so that tensors only kept alive by
    reference cycles are destroyed before the CUDA caching allocator is
    asked to give its unused blocks back; in the opposite order those
    blocks would stay cached until the next call.
    """
    gc.collect()
    torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
class LoRAHack:
    """Options for a LoRA hack run, built from keyword arguments.

    Only ``type`` is read; it defaults to ``'suppression'``. Any other
    keyword is accepted and ignored.
    """

    def __init__(self, **kwargs):
        if 'type' in kwargs:
            hack_type = kwargs['type']
        else:
            hack_type = 'suppression'
        self.type = hack_type
|
||||||
|
|
||||||
|
|
||||||
|
class TrainLoRAHack(BaseSDTrainProcess):
    """Process that perturbs the LoRA network weights with random noise.

    No loss is optimised here: the network parameters are overwritten with
    small random values before the loop starts and nudged by further noise
    on every training step. Only the ``'suppression'`` hack type is
    implemented.
    """

    # scale applied to the sampled noise before it touches the parameters
    _NOISE_SCALE = 0.001

    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        self.hack_config = LoRAHack(**self.get_conf('hack', {}))

    @staticmethod
    def _sample_noise(param):
        """Return scaled random noise with the same shape as ``param``."""
        # NOTE(review): randn_like is already zero-mean, so the -0.5 shifts
        # the mean instead of centring it (rand_like may have been intended).
        # Kept as-is to preserve the current behaviour - confirm.
        return (torch.randn_like(param) - 0.5) * TrainLoRAHack._NOISE_SCALE

    def hook_before_train_loop(self):
        """Park the text encoder(s) on the CPU and re-initialise the network.

        For the ``'suppression'`` hack every network parameter is replaced
        by freshly sampled noise.
        """
        # we don't need text encoder so move it to cpu; SDXL models carry a
        # list of text encoders, other models a single one
        text_encoders = self.sd.text_encoder
        if not isinstance(text_encoders, list):
            text_encoders = [text_encoders]
        for encoder in text_encoders:
            encoder.to("cpu")
        flush()

        if self.hack_config.type == 'suppression':
            for param in self.network.parameters():
                param.data = self._sample_noise(param)

    def supress_loop(self):
        """Add one more round of noise to every network parameter.

        Returns:
            OrderedDict with the single key ``'sup'`` set to ``0.0``; no
            loss is computed by this step.
        """
        loss_dict = OrderedDict(
            {'sup': 0.0}
        )
        # increase noise
        for param in self.network.parameters():
            param.data = param.data + self._sample_noise(param)

        return loss_dict

    def hook_train_loop(self):
        """Run one training step for the configured hack type.

        Raises:
            NotImplementedError: if the hack type is not ``'suppression'``.
        """
        if self.hack_config.type == 'suppression':
            return self.supress_loop()
        raise NotImplementedError(f'unknown hack type: {self.hack_config.type}')
|
||||||
@@ -3,19 +3,22 @@
|
|||||||
import time
|
import time
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import os
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from toolkit.config_modules import SliderConfig
|
from toolkit.config_modules import SliderConfig
|
||||||
from toolkit.paths import REPOS_ROOT
|
from toolkit.paths import REPOS_ROOT
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
from toolkit.stable_diffusion_model import PromptEmbeds
|
||||||
|
|
||||||
sys.path.append(REPOS_ROOT)
|
sys.path.append(REPOS_ROOT)
|
||||||
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
|
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
|
||||||
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||||
import gc
|
import gc
|
||||||
|
from toolkit import train_tools
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from leco import train_util, model_util
|
from leco import train_util, model_util
|
||||||
from leco.prompt_util import PromptEmbedsCache
|
|
||||||
from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion
|
from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion
|
||||||
|
|
||||||
|
|
||||||
@@ -29,7 +32,6 @@ def flush():
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class EncodedPromptPair:
|
class EncodedPromptPair:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -54,6 +56,19 @@ class EncodedPromptPair:
|
|||||||
self.weight = weight
|
self.weight = weight
|
||||||
|
|
||||||
|
|
||||||
|
class PromptEmbedsCache:
    """Cache of encoded prompts, keyed by the prompt text, so they can be reused.

    Each instance owns its own dictionary. A dictionary defined on the class
    would be shared by every cache, so embeddings stored through one cache
    would be returned by all the others.
    """

    # prompt text -> encoded prompt; created per instance in __init__
    prompts: "dict[str, PromptEmbeds]"

    def __init__(self) -> None:
        self.prompts = {}

    def __setitem__(self, __name: str, __value: "PromptEmbeds") -> None:
        """Store the embeddings for the given prompt text."""
        self.prompts[__name] = __value

    def __getitem__(self, __name: str) -> Optional["PromptEmbeds"]:
        """Return the cached embeddings, or None when the prompt is not cached."""
        return self.prompts.get(__name)
|
||||||
|
|
||||||
|
|
||||||
class EncodedAnchor:
|
class EncodedAnchor:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -89,19 +104,17 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
neutral = ""
|
neutral = ""
|
||||||
for target in self.slider_config.targets:
|
for target in self.slider_config.targets:
|
||||||
|
# build the cache
|
||||||
|
for prompt in [
|
||||||
|
target.target_class,
|
||||||
|
target.positive,
|
||||||
|
target.negative,
|
||||||
|
neutral # empty neutral
|
||||||
|
]:
|
||||||
|
if cache[prompt] is None:
|
||||||
|
cache[prompt] = self.sd.encode_prompt(prompt)
|
||||||
for resolution in self.slider_config.resolutions:
|
for resolution in self.slider_config.resolutions:
|
||||||
width, height = resolution
|
width, height = resolution
|
||||||
# build the cache
|
|
||||||
for prompt in [
|
|
||||||
target.target_class,
|
|
||||||
target.positive,
|
|
||||||
target.negative,
|
|
||||||
neutral # empty neutral
|
|
||||||
]:
|
|
||||||
if cache[prompt] == None:
|
|
||||||
cache[prompt] = train_util.encode_prompts(
|
|
||||||
self.sd.tokenizer, self.sd.text_encoder, [prompt]
|
|
||||||
)
|
|
||||||
only_erase = len(target.positive.strip()) == 0
|
only_erase = len(target.positive.strip()) == 0
|
||||||
only_enhance = len(target.negative.strip()) == 0
|
only_enhance = len(target.negative.strip()) == 0
|
||||||
|
|
||||||
@@ -184,9 +197,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
anchor.neg_prompt # empty neutral
|
anchor.neg_prompt # empty neutral
|
||||||
]:
|
]:
|
||||||
if cache[prompt] == None:
|
if cache[prompt] == None:
|
||||||
cache[prompt] = train_util.encode_prompts(
|
cache[prompt] = self.sd.encode_prompt(prompt)
|
||||||
self.sd.tokenizer, self.sd.text_encoder, [prompt]
|
|
||||||
)
|
|
||||||
|
|
||||||
anchor_pairs += [
|
anchor_pairs += [
|
||||||
EncodedAnchor(
|
EncodedAnchor(
|
||||||
@@ -198,7 +209,12 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
|
|
||||||
# move to cpu to save vram
|
# move to cpu to save vram
|
||||||
# We don't need text encoder anymore, but keep it on cpu for sampling
|
# We don't need text encoder anymore, but keep it on cpu for sampling
|
||||||
self.sd.text_encoder.to("cpu")
|
# if text encoder is list
|
||||||
|
if isinstance(self.sd.text_encoder, list):
|
||||||
|
for encoder in self.sd.text_encoder:
|
||||||
|
encoder.to("cpu")
|
||||||
|
else:
|
||||||
|
self.sd.text_encoder.to("cpu")
|
||||||
self.prompt_cache = cache
|
self.prompt_cache = cache
|
||||||
self.prompt_pairs = prompt_pairs
|
self.prompt_pairs = prompt_pairs
|
||||||
self.anchor_pairs = anchor_pairs
|
self.anchor_pairs = anchor_pairs
|
||||||
@@ -220,6 +236,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
negative = prompt_pair.negative
|
negative = prompt_pair.negative
|
||||||
positive = prompt_pair.positive
|
positive = prompt_pair.positive
|
||||||
weight = prompt_pair.weight
|
weight = prompt_pair.weight
|
||||||
|
multiplier = prompt_pair.multiplier
|
||||||
|
|
||||||
unet = self.sd.unet
|
unet = self.sd.unet
|
||||||
noise_scheduler = self.sd.noise_scheduler
|
noise_scheduler = self.sd.noise_scheduler
|
||||||
@@ -227,8 +244,20 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
lr_scheduler = self.lr_scheduler
|
lr_scheduler = self.lr_scheduler
|
||||||
loss_function = torch.nn.MSELoss()
|
loss_function = torch.nn.MSELoss()
|
||||||
|
|
||||||
|
def get_noise_pred(p, n):
|
||||||
|
return self.predict_noise(
|
||||||
|
latents=denoised_latents,
|
||||||
|
text_embeddings=train_tools.concat_prompt_embeddings(
|
||||||
|
p, # unconditional
|
||||||
|
n, # positive
|
||||||
|
self.train_config.batch_size,
|
||||||
|
),
|
||||||
|
timestep=current_timestep,
|
||||||
|
guidance_scale=1,
|
||||||
|
)
|
||||||
|
|
||||||
# set network multiplier
|
# set network multiplier
|
||||||
self.network.multiplier = prompt_pair.multiplier
|
self.network.multiplier = multiplier
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.sd.noise_scheduler.set_timesteps(
|
self.sd.noise_scheduler.set_timesteps(
|
||||||
@@ -254,9 +283,10 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
|
|
||||||
with self.network:
|
with self.network:
|
||||||
assert self.network.is_active
|
assert self.network.is_active
|
||||||
|
self.network.multiplier = multiplier
|
||||||
denoised_latents = self.diffuse_some_steps(
|
denoised_latents = self.diffuse_some_steps(
|
||||||
latents, # pass simple noise latents
|
latents, # pass simple noise latents
|
||||||
train_util.concat_embeddings(
|
train_tools.concat_prompt_embeddings(
|
||||||
positive, # unconditional
|
positive, # unconditional
|
||||||
target_class, # target
|
target_class, # target
|
||||||
self.train_config.batch_size,
|
self.train_config.batch_size,
|
||||||
@@ -272,43 +302,11 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
|
int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
|
||||||
]
|
]
|
||||||
|
|
||||||
# with network: 0 weight LoRA is enabled outside "with network:"
|
positive_latents = get_noise_pred(positive, negative)
|
||||||
positive_latents = train_util.predict_noise( # positive_latents
|
|
||||||
unet,
|
neutral_latents = get_noise_pred(positive, neutral)
|
||||||
noise_scheduler,
|
|
||||||
current_timestep,
|
unconditional_latents = get_noise_pred(positive, positive)
|
||||||
denoised_latents,
|
|
||||||
train_util.concat_embeddings(
|
|
||||||
positive, # unconditional
|
|
||||||
negative, # positive
|
|
||||||
self.train_config.batch_size,
|
|
||||||
),
|
|
||||||
guidance_scale=1,
|
|
||||||
).to("cpu", dtype=torch.float32)
|
|
||||||
neutral_latents = train_util.predict_noise(
|
|
||||||
unet,
|
|
||||||
noise_scheduler,
|
|
||||||
current_timestep,
|
|
||||||
denoised_latents,
|
|
||||||
train_util.concat_embeddings(
|
|
||||||
positive, # unconditional
|
|
||||||
neutral, # neutral
|
|
||||||
self.train_config.batch_size,
|
|
||||||
),
|
|
||||||
guidance_scale=1,
|
|
||||||
).to("cpu", dtype=torch.float32)
|
|
||||||
unconditional_latents = train_util.predict_noise(
|
|
||||||
unet,
|
|
||||||
noise_scheduler,
|
|
||||||
current_timestep,
|
|
||||||
denoised_latents,
|
|
||||||
train_util.concat_embeddings(
|
|
||||||
positive, # unconditional
|
|
||||||
positive, # unconditional
|
|
||||||
self.train_config.batch_size,
|
|
||||||
),
|
|
||||||
guidance_scale=1,
|
|
||||||
).to("cpu", dtype=torch.float32)
|
|
||||||
|
|
||||||
anchor_loss = None
|
anchor_loss = None
|
||||||
if len(self.anchor_pairs) > 0:
|
if len(self.anchor_pairs) > 0:
|
||||||
@@ -317,51 +315,19 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
torch.randint(0, len(self.anchor_pairs), (1,)).item()
|
torch.randint(0, len(self.anchor_pairs), (1,)).item()
|
||||||
]
|
]
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
anchor_target_noise = train_util.predict_noise(
|
anchor_target_noise = get_noise_pred(anchor.prompt, anchor.neg_prompt)
|
||||||
unet,
|
|
||||||
noise_scheduler,
|
|
||||||
current_timestep,
|
|
||||||
denoised_latents,
|
|
||||||
train_util.concat_embeddings(
|
|
||||||
anchor.prompt,
|
|
||||||
anchor.neg_prompt,
|
|
||||||
self.train_config.batch_size,
|
|
||||||
),
|
|
||||||
guidance_scale=1,
|
|
||||||
).to("cpu", dtype=torch.float32)
|
|
||||||
with self.network:
|
with self.network:
|
||||||
# anchor whatever weight prompt pair is using
|
# anchor whatever weight prompt pair is using
|
||||||
pos_nem_mult = 1.0 if prompt_pair.multiplier > 0 else -1.0
|
pos_nem_mult = 1.0 if prompt_pair.multiplier > 0 else -1.0
|
||||||
self.network.multiplier = anchor.multiplier * pos_nem_mult
|
self.network.multiplier = anchor.multiplier * pos_nem_mult
|
||||||
anchor_pred_noise = train_util.predict_noise(
|
|
||||||
unet,
|
anchor_pred_noise = get_noise_pred(anchor.prompt, anchor.neg_prompt)
|
||||||
noise_scheduler,
|
|
||||||
current_timestep,
|
|
||||||
denoised_latents,
|
|
||||||
train_util.concat_embeddings(
|
|
||||||
anchor.prompt,
|
|
||||||
anchor.neg_prompt,
|
|
||||||
self.train_config.batch_size,
|
|
||||||
),
|
|
||||||
guidance_scale=1,
|
|
||||||
).to("cpu", dtype=torch.float32)
|
|
||||||
|
|
||||||
self.network.multiplier = prompt_pair.multiplier
|
self.network.multiplier = prompt_pair.multiplier
|
||||||
|
|
||||||
with self.network:
|
with self.network:
|
||||||
self.network.multiplier = prompt_pair.multiplier
|
self.network.multiplier = prompt_pair.multiplier
|
||||||
target_latents = train_util.predict_noise(
|
target_latents = get_noise_pred(positive, target_class)
|
||||||
unet,
|
|
||||||
noise_scheduler,
|
|
||||||
current_timestep,
|
|
||||||
denoised_latents,
|
|
||||||
train_util.concat_embeddings(
|
|
||||||
positive, # unconditional
|
|
||||||
target_class, # target
|
|
||||||
self.train_config.batch_size,
|
|
||||||
),
|
|
||||||
guidance_scale=1,
|
|
||||||
).to("cpu", dtype=torch.float32)
|
|
||||||
|
|
||||||
# if self.logging_config.verbose:
|
# if self.logging_config.verbose:
|
||||||
# self.print("target_latents:", target_latents[0, 0, :5, :5])
|
# self.print("target_latents:", target_latents[0, 0, :5, :5])
|
||||||
|
|||||||
@@ -6,3 +6,4 @@ from .BaseTrainProcess import BaseTrainProcess
|
|||||||
from .TrainVAEProcess import TrainVAEProcess
|
from .TrainVAEProcess import TrainVAEProcess
|
||||||
from .BaseMergeProcess import BaseMergeProcess
|
from .BaseMergeProcess import BaseMergeProcess
|
||||||
from .TrainSliderProcess import TrainSliderProcess
|
from .TrainSliderProcess import TrainSliderProcess
|
||||||
|
from .TrainLoRAHack import TrainLoRAHack
|
||||||
|
|||||||
@@ -9,4 +9,5 @@ accelerator
|
|||||||
pyyaml
|
pyyaml
|
||||||
oyaml
|
oyaml
|
||||||
tensorboard
|
tensorboard
|
||||||
kornia
|
kornia
|
||||||
|
invisible-watermark
|
||||||
@@ -50,12 +50,14 @@ class TrainConfig:
|
|||||||
self.train_text_encoder = kwargs.get('train_text_encoder', True)
|
self.train_text_encoder = kwargs.get('train_text_encoder', True)
|
||||||
self.noise_offset = kwargs.get('noise_offset', 0.0)
|
self.noise_offset = kwargs.get('noise_offset', 0.0)
|
||||||
self.optimizer_params = kwargs.get('optimizer_params', {})
|
self.optimizer_params = kwargs.get('optimizer_params', {})
|
||||||
|
self.skip_first_sample = kwargs.get('skip_first_sample', False)
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig:
|
class ModelConfig:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.name_or_path: str = kwargs.get('name_or_path', None)
|
self.name_or_path: str = kwargs.get('name_or_path', None)
|
||||||
self.is_v2: bool = kwargs.get('is_v2', False)
|
self.is_v2: bool = kwargs.get('is_v2', False)
|
||||||
|
self.is_xl: bool = kwargs.get('is_xl', False)
|
||||||
self.is_v_pred: bool = kwargs.get('is_v_pred', False)
|
self.is_v_pred: bool = kwargs.get('is_v_pred', False)
|
||||||
|
|
||||||
if self.name_or_path is None:
|
if self.name_or_path is None:
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from typing import List
|
from typing import List, Optional, Dict, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from transformers import CLIPTextModel
|
||||||
|
|
||||||
from .paths import SD_SCRIPTS_ROOT
|
from .paths import SD_SCRIPTS_ROOT
|
||||||
|
|
||||||
sys.path.append(SD_SCRIPTS_ROOT)
|
sys.path.append(SD_SCRIPTS_ROOT)
|
||||||
@@ -14,26 +16,40 @@ class LoRASpecialNetwork(LoRANetwork):
|
|||||||
_multiplier: float = 1.0
|
_multiplier: float = 1.0
|
||||||
is_active: bool = False
|
is_active: bool = False
|
||||||
|
|
||||||
|
NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
|
||||||
|
|
||||||
|
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
||||||
|
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||||
|
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||||
|
LORA_PREFIX_UNET = "lora_unet"
|
||||||
|
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
||||||
|
|
||||||
|
# SDXL: must starts with LORA_PREFIX_TEXT_ENCODER
|
||||||
|
LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
|
||||||
|
LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
text_encoder,
|
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
|
||||||
unet,
|
unet,
|
||||||
multiplier=1.0,
|
multiplier: float = 1.0,
|
||||||
lora_dim=4,
|
lora_dim: int = 4,
|
||||||
alpha=1,
|
alpha: float = 1,
|
||||||
dropout=None,
|
dropout: Optional[float] = None,
|
||||||
rank_dropout=None,
|
rank_dropout: Optional[float] = None,
|
||||||
module_dropout=None,
|
module_dropout: Optional[float] = None,
|
||||||
conv_lora_dim=None,
|
conv_lora_dim: Optional[int] = None,
|
||||||
conv_alpha=None,
|
conv_alpha: Optional[float] = None,
|
||||||
block_dims=None,
|
block_dims: Optional[List[int]] = None,
|
||||||
block_alphas=None,
|
block_alphas: Optional[List[float]] = None,
|
||||||
conv_block_dims=None,
|
conv_block_dims: Optional[List[int]] = None,
|
||||||
conv_block_alphas=None,
|
conv_block_alphas: Optional[List[float]] = None,
|
||||||
modules_dim=None,
|
modules_dim: Optional[Dict[str, int]] = None,
|
||||||
modules_alpha=None,
|
modules_alpha: Optional[Dict[str, int]] = None,
|
||||||
module_class=LoRAModule,
|
module_class: Type[object] = LoRAModule,
|
||||||
varbose=False,
|
varbose: Optional[bool] = False,
|
||||||
|
train_text_encoder: Optional[bool] = True,
|
||||||
|
train_unet: Optional[bool] = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
LoRA network: すごく引数が多いが、パターンは以下の通り
|
LoRA network: すごく引数が多いが、パターンは以下の通り
|
||||||
@@ -75,8 +91,21 @@ class LoRASpecialNetwork(LoRANetwork):
|
|||||||
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
||||||
|
|
||||||
# create module instances
|
# create module instances
|
||||||
def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
|
def create_modules(
|
||||||
prefix = LoRANetwork.LORA_PREFIX_UNET if is_unet else LoRANetwork.LORA_PREFIX_TEXT_ENCODER
|
is_unet: bool,
|
||||||
|
text_encoder_idx: Optional[int], # None, 1, 2
|
||||||
|
root_module: torch.nn.Module,
|
||||||
|
target_replace_modules: List[torch.nn.Module],
|
||||||
|
) -> List[LoRAModule]:
|
||||||
|
prefix = (
|
||||||
|
self.LORA_PREFIX_UNET
|
||||||
|
if is_unet
|
||||||
|
else (
|
||||||
|
self.LORA_PREFIX_TEXT_ENCODER
|
||||||
|
if text_encoder_idx is None
|
||||||
|
else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
|
||||||
|
)
|
||||||
|
)
|
||||||
loras = []
|
loras = []
|
||||||
skipped = []
|
skipped = []
|
||||||
for name, module in root_module.named_modules():
|
for name, module in root_module.named_modules():
|
||||||
@@ -92,11 +121,14 @@ class LoRASpecialNetwork(LoRANetwork):
|
|||||||
|
|
||||||
dim = None
|
dim = None
|
||||||
alpha = None
|
alpha = None
|
||||||
|
|
||||||
if modules_dim is not None:
|
if modules_dim is not None:
|
||||||
|
# モジュール指定あり
|
||||||
if lora_name in modules_dim:
|
if lora_name in modules_dim:
|
||||||
dim = modules_dim[lora_name]
|
dim = modules_dim[lora_name]
|
||||||
alpha = modules_alpha[lora_name]
|
alpha = modules_alpha[lora_name]
|
||||||
elif is_unet and block_dims is not None:
|
elif is_unet and block_dims is not None:
|
||||||
|
# U-Netでblock_dims指定あり
|
||||||
block_idx = get_block_index(lora_name)
|
block_idx = get_block_index(lora_name)
|
||||||
if is_linear or is_conv2d_1x1:
|
if is_linear or is_conv2d_1x1:
|
||||||
dim = block_dims[block_idx]
|
dim = block_dims[block_idx]
|
||||||
@@ -105,6 +137,7 @@ class LoRASpecialNetwork(LoRANetwork):
|
|||||||
dim = conv_block_dims[block_idx]
|
dim = conv_block_dims[block_idx]
|
||||||
alpha = conv_block_alphas[block_idx]
|
alpha = conv_block_alphas[block_idx]
|
||||||
else:
|
else:
|
||||||
|
# 通常、すべて対象とする
|
||||||
if is_linear or is_conv2d_1x1:
|
if is_linear or is_conv2d_1x1:
|
||||||
dim = self.lora_dim
|
dim = self.lora_dim
|
||||||
alpha = self.alpha
|
alpha = self.alpha
|
||||||
@@ -113,6 +146,7 @@ class LoRASpecialNetwork(LoRANetwork):
|
|||||||
alpha = self.conv_alpha
|
alpha = self.conv_alpha
|
||||||
|
|
||||||
if dim is None or dim == 0:
|
if dim is None or dim == 0:
|
||||||
|
# skipした情報を出力
|
||||||
if is_linear or is_conv2d_1x1 or (
|
if is_linear or is_conv2d_1x1 or (
|
||||||
self.conv_lora_dim is not None or conv_block_dims is not None):
|
self.conv_lora_dim is not None or conv_block_dims is not None):
|
||||||
skipped.append(lora_name)
|
skipped.append(lora_name)
|
||||||
@@ -131,8 +165,25 @@ class LoRASpecialNetwork(LoRANetwork):
|
|||||||
loras.append(lora)
|
loras.append(lora)
|
||||||
return loras, skipped
|
return loras, skipped
|
||||||
|
|
||||||
self.text_encoder_loras, skipped_te = create_modules(False, text_encoder,
|
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
|
||||||
|
|
||||||
|
# create LoRA for text encoder
|
||||||
|
# 毎回すべてのモジュールを作るのは無駄なので要検討
|
||||||
|
self.text_encoder_loras = []
|
||||||
|
skipped_te = []
|
||||||
|
if train_text_encoder:
|
||||||
|
for i, text_encoder in enumerate(text_encoders):
|
||||||
|
if len(text_encoders) > 1:
|
||||||
|
index = i + 1
|
||||||
|
print(f"create LoRA for Text Encoder {index}:")
|
||||||
|
else:
|
||||||
|
index = None
|
||||||
|
print(f"create LoRA for Text Encoder:")
|
||||||
|
|
||||||
|
text_encoder_loras, skipped = create_modules(False, index, text_encoder,
|
||||||
LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||||
|
self.text_encoder_loras.extend(text_encoder_loras)
|
||||||
|
skipped_te += skipped
|
||||||
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||||
|
|
||||||
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
||||||
@@ -140,7 +191,11 @@ class LoRASpecialNetwork(LoRANetwork):
|
|||||||
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
|
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
|
||||||
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||||
|
|
||||||
self.unet_loras, skipped_un = create_modules(True, unet, target_modules)
|
if train_unet:
|
||||||
|
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
||||||
|
else:
|
||||||
|
self.unet_loras = []
|
||||||
|
skipped_un = []
|
||||||
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||||
|
|
||||||
skipped = skipped_te + skipped_un
|
skipped = skipped_te + skipped_un
|
||||||
@@ -159,8 +214,7 @@ class LoRASpecialNetwork(LoRANetwork):
|
|||||||
# assertion
|
# assertion
|
||||||
names = set()
|
names = set()
|
||||||
for lora in self.text_encoder_loras + self.unet_loras:
|
for lora in self.text_encoder_loras + self.unet_loras:
|
||||||
# doesnt work on new diffusers. TODO make sure we are not missing something
|
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
||||||
# assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
|
||||||
names.add(lora.lora_name)
|
names.add(lora.lora_name)
|
||||||
|
|
||||||
def save_weights(self, file, dtype, metadata):
|
def save_weights(self, file, dtype, metadata):
|
||||||
|
|||||||
63
toolkit/stable_diffusion_model.py
Normal file
63
toolkit/stable_diffusion_model.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
from typing import Union
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from toolkit.paths import REPOS_ROOT
|
||||||
|
sys.path.append(REPOS_ROOT)
|
||||||
|
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
|
||||||
|
from leco import train_util
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class PromptEmbeds:
|
||||||
|
text_embeds: torch.FloatTensor
|
||||||
|
pooled_embeds: Union[torch.FloatTensor, None]
|
||||||
|
|
||||||
|
def __init__(self, args) -> None:
|
||||||
|
if isinstance(args, list) or isinstance(args, tuple):
|
||||||
|
# xl
|
||||||
|
self.text_embeds = args[0]
|
||||||
|
self.pooled_embeds = args[1]
|
||||||
|
else:
|
||||||
|
# sdv1.x, sdv2.x
|
||||||
|
self.text_embeds = args
|
||||||
|
self.pooled_embeds = None
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusion:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vae,
|
||||||
|
tokenizer,
|
||||||
|
text_encoder,
|
||||||
|
unet,
|
||||||
|
noise_scheduler,
|
||||||
|
is_xl=False
|
||||||
|
):
|
||||||
|
# text encoder has a list of 2 for xl
|
||||||
|
self.vae = vae
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.text_encoder = text_encoder
|
||||||
|
self.unet = unet
|
||||||
|
self.noise_scheduler = noise_scheduler
|
||||||
|
self.is_xl = is_xl
|
||||||
|
|
||||||
|
def encode_prompt(self, prompt, num_images_per_prompt=1) -> PromptEmbeds:
|
||||||
|
prompt = prompt
|
||||||
|
# if it is not a list, make it one
|
||||||
|
if not isinstance(prompt, list):
|
||||||
|
prompt = [prompt]
|
||||||
|
if self.is_xl:
|
||||||
|
return PromptEmbeds(
|
||||||
|
train_util.encode_prompts_xl(
|
||||||
|
self.tokenizer,
|
||||||
|
self.text_encoder,
|
||||||
|
prompt,
|
||||||
|
num_images_per_prompt=num_images_per_prompt,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return PromptEmbeds(
|
||||||
|
train_util.encode_prompts(
|
||||||
|
self.tokenizer, self.text_encoder, prompt
|
||||||
|
)
|
||||||
|
)
|
||||||
@@ -21,6 +21,8 @@ from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipel
|
|||||||
import torch
|
import torch
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from toolkit.stable_diffusion_model import PromptEmbeds
|
||||||
|
|
||||||
SCHEDULER_LINEAR_START = 0.00085
|
SCHEDULER_LINEAR_START = 0.00085
|
||||||
SCHEDULER_LINEAR_END = 0.0120
|
SCHEDULER_LINEAR_END = 0.0120
|
||||||
SCHEDULER_TIMESTEPS = 1000
|
SCHEDULER_TIMESTEPS = 1000
|
||||||
@@ -377,3 +379,19 @@ def apply_noise_offset(noise, noise_offset):
|
|||||||
return noise
|
return noise
|
||||||
noise = noise + noise_offset * torch.randn((noise.shape[0], noise.shape[1], 1, 1), device=noise.device)
|
noise = noise + noise_offset * torch.randn((noise.shape[0], noise.shape[1], 1, 1), device=noise.device)
|
||||||
return noise
|
return noise
|
||||||
|
|
||||||
|
|
||||||
|
def concat_prompt_embeddings(
|
||||||
|
unconditional: PromptEmbeds,
|
||||||
|
conditional: PromptEmbeds,
|
||||||
|
n_imgs: int,
|
||||||
|
):
|
||||||
|
text_embeds = torch.cat(
|
||||||
|
[unconditional.text_embeds, conditional.text_embeds]
|
||||||
|
).repeat_interleave(n_imgs, dim=0)
|
||||||
|
pooled_embeds = None
|
||||||
|
if unconditional.pooled_embeds is not None and conditional.pooled_embeds is not None:
|
||||||
|
pooled_embeds = torch.cat(
|
||||||
|
[unconditional.pooled_embeds, conditional.pooled_embeds]
|
||||||
|
).repeat_interleave(n_imgs, dim=0)
|
||||||
|
return PromptEmbeds([text_embeds, pooled_embeds])
|
||||||
|
|||||||
Reference in New Issue
Block a user