Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Added rescaling, locon, sdxl, all kinds of stuff. sdxl is still weird
@@ -18,6 +18,7 @@ process_dict = {
     'vae': 'TrainVAEProcess',
     'slider': 'TrainSliderProcess',
     'lora_hack': 'TrainLoRAHack',
+    'rescale_sd': 'TrainSDRescaleProcess',
 }
@@ -1,7 +1,10 @@
+import glob
 import time
 from collections import OrderedDict
 import os
 
+from safetensors import safe_open
+
 from toolkit.kohya_model_util import load_vae
 from toolkit.lora_special import LoRASpecialNetwork
 from toolkit.optimizer import get_optimizer
@@ -14,7 +17,7 @@ sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
 from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
 
 from jobs.process import BaseTrainProcess
-from toolkit.metadata import get_meta_for_safetensors
+from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors
 from toolkit.train_tools import get_torch_dtype, apply_noise_offset
 import gc
@@ -48,6 +51,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.model_config = ModelConfig(**self.get_conf('model', {}))
         self.save_config = SaveConfig(**self.get_conf('save', {}))
         self.sample_config = SampleConfig(**self.get_conf('sample', {}))
+        self.first_sample_config = SampleConfig(**self.get_conf('first_sample', {})) if 'first_sample' in self.config else self.sample_config
         self.logging_config = LogingConfig(**self.get_conf('logging', {}))
         self.optimizer = None
         self.lr_scheduler = None
@@ -56,7 +60,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # added later
         self.network = None
 
-    def sample(self, step=None):
+    def sample(self, step=None, is_first=False):
         sample_folder = os.path.join(self.save_root, 'samples')
         if not os.path.exists(sample_folder):
             os.makedirs(sample_folder, exist_ok=True)
@@ -112,7 +116,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # disable progress bar
         pipeline.set_progress_bar_config(disable=True)
 
-        start_seed = self.sample_config.seed
+        sample_config = self.first_sample_config if is_first else self.sample_config
+
+        start_seed = sample_config.seed
         start_multiplier = self.network.multiplier
         current_seed = start_seed
@@ -127,14 +133,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
             'multiplier': self.network.multiplier,
         })
 
-        for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}",
+        for i in tqdm(range(len(sample_config.prompts)), desc=f"Generating Samples - step: {step}",
                       leave=False):
-            raw_prompt = self.sample_config.prompts[i]
+            raw_prompt = sample_config.prompts[i]
 
-            neg = self.sample_config.neg
-            multiplier = self.sample_config.network_multiplier
+            neg = sample_config.neg
+            multiplier = sample_config.network_multiplier
             p_split = raw_prompt.split('--')
             prompt = p_split[0].strip()
+            height = sample_config.height
+            width = sample_config.width
 
             if len(p_split) > 1:
                 for split in p_split:
@@ -145,13 +153,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     elif flag == 'm':
                         # multiplier
                         multiplier = float(content)
+                    elif flag == 'w':
+                        # width
+                        width = int(content)
+                    elif flag == 'h':
+                        # height
+                        height = int(content)
 
-            height = self.sample_config.height
-            width = self.sample_config.width
             height = max(64, height - height % 8)  # round to divisible by 8
             width = max(64, width - width % 8)  # round to divisible by 8
 
-            if self.sample_config.walk_seed:
+            if sample_config.walk_seed:
                 current_seed += i
 
             if self.network is not None:
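Note: the flag parsing added above lets a sample prompt carry inline overrides (--m, --w, --h) after a `--` separator. A standalone sketch of the same parsing; the helper name, defaults, and use of str.partition are illustrative, not part of the commit:

def parse_sample_prompt(raw_prompt, default_w=512, default_h=512, default_m=1.0):
    # Minimal sketch of the per-prompt flag parsing used above.
    parts = raw_prompt.split('--')
    prompt = parts[0].strip()
    width, height, multiplier = default_w, default_h, default_m
    for part in parts[1:]:
        part = part.strip()
        if not part:
            continue
        flag, _, content = part.partition(' ')
        if flag == 'm':    # network multiplier
            multiplier = float(content)
        elif flag == 'w':  # width in pixels
            width = int(content)
        elif flag == 'h':  # height in pixels
            height = int(content)
    # round down to multiples of 8, as the sampler requires
    width = max(64, width - width % 8)
    height = max(64, height - height % 8)
    return prompt, width, height, multiplier

# parse_sample_prompt("a photo of a cat --w 768 --h 512 --m 0.8")
# -> ("a photo of a cat", 768, 512, 0.8)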
@@ -159,14 +171,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
             torch.manual_seed(current_seed)
             torch.cuda.manual_seed(current_seed)
 
-            img = pipeline(
-                prompt,
-                height=height,
-                width=width,
-                num_inference_steps=self.sample_config.sample_steps,
-                guidance_scale=self.sample_config.guidance_scale,
-                negative_prompt=neg,
-            ).images[0]
+            if self.sd.is_xl:
+                img = pipeline(
+                    prompt,
+                    height=height,
+                    width=width,
+                    num_inference_steps=sample_config.sample_steps,
+                    guidance_scale=sample_config.guidance_scale,
+                    negative_prompt=neg,
+                ).images[0]
+            else:
+                img = pipeline(
+                    prompt,
+                    height=height,
+                    width=width,
+                    num_inference_steps=sample_config.sample_steps,
+                    guidance_scale=sample_config.guidance_scale,
+                    negative_prompt=neg,
+                ).images[0]
 
             step_num = ''
             if step is not None:
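Note: with walk_seed enabled, current_seed is advanced by the loop index on each iteration (start, start+1, start+3, start+6, ...) and both RNGs are reseeded before every sample, so the sample grid stays reproducible across training steps. The behavior in isolation, with illustrative values:

import torch

start_seed = 42
current_seed = start_seed
for i in range(4):  # one iteration per sample prompt
    current_seed += i
    torch.manual_seed(current_seed)           # seed CPU RNG
    if torch.cuda.is_available():
        torch.cuda.manual_seed(current_seed)  # seed GPU RNG to match
    # ... run the pipeline for prompt i here ...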
@@ -209,6 +231,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
         })
         return info
 
+    def clean_up_saves(self):
+        # remove old saves
+        # get latest saved step
+        if os.path.exists(self.save_root):
+            latest_file = None
+            # pattern is {job_name}_{zero_filled_step}.safetensors but NOT {job_name}.safetensors
+            pattern = f"{self.job.name}_*.safetensors"
+            files = glob.glob(os.path.join(self.save_root, pattern))
+            if len(files) > self.save_config.max_step_saves_to_keep:
+                # remove all but the latest max_step_saves_to_keep
+                files.sort(key=os.path.getctime)
+                for file in files[:-self.save_config.max_step_saves_to_keep]:
+                    self.print(f"Removing old save: {file}")
+                    os.remove(file)
+            return latest_file
+        else:
+            return None
+
     def save(self, step=None):
         if not os.path.exists(self.save_root):
             os.makedirs(self.save_root, exist_ok=True)
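Note: clean_up_saves keeps only the newest max_step_saves_to_keep step checkpoints, sorted by file creation time. The same retention logic restated in isolation; the helper name and its dry_run flag are illustrative, not part of the commit:

import glob
import os

def prune_old_saves(save_root, job_name, keep=5, dry_run=True):
    # Match only step checkpoints: {job_name}_{step}.safetensors
    pattern = os.path.join(save_root, f"{job_name}_*.safetensors")
    files = sorted(glob.glob(pattern), key=os.path.getctime)
    # delete everything except the newest `keep` files
    for path in (files[:-keep] if len(files) > keep else []):
        print(f"Removing old save: {path}")
        if not dry_run:
            os.remove(path)

# prune_old_saves("output/my_job", "my_job", keep=5)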
@@ -231,9 +271,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 metadata=save_meta
             )
         else:
             # TODO handle dreambooth, fine tuning, etc
-            # will probably have to convert dict back to LDM
-            ValueError("Non network training is not currently supported")
+            self.sd.save(
+                file_path,
+                save_meta,
+                get_torch_dtype(self.save_config.dtype)
+            )
 
         self.print(f"Saved to {file_path}")
@@ -258,6 +300,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
     ):
         if height is None and pixel_height is None:
             raise ValueError("height or pixel_height must be specified")
         if width is None and pixel_width is None:
             raise ValueError("width or pixel_width must be specified")
         if height is None:
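Note: get_latent_noise accepts either latent dimensions or pixel dimensions. Stable Diffusion's VAE downsamples by a factor of 8 in each spatial dimension, so a plausible sketch of the conversion these checks guard; the constant and helper name are assumptions, not shown in the diff:

VAE_SCALE_FACTOR = 8  # SD's VAE downsamples images 8x per spatial dim

def latent_size(height=None, width=None, pixel_height=None, pixel_width=None):
    # Mirrors the argument checks above: accept either latent or pixel sizes.
    if height is None and pixel_height is None:
        raise ValueError("height or pixel_height must be specified")
    if width is None and pixel_width is None:
        raise ValueError("width or pixel_width must be specified")
    if height is None:
        height = pixel_height // VAE_SCALE_FACTOR
    if width is None:
        width = pixel_width // VAE_SCALE_FACTOR
    return height, width

# latent_size(pixel_height=512, pixel_width=512) -> (64, 64)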
@@ -316,18 +359,47 @@ class BaseSDTrainProcess(BaseTrainProcess):
             if add_time_ids is None:
                 add_time_ids = self.get_time_ids_from_latents(latents)
-            # todo LECOs code looks like it is omitting noise_pred
-            noise_pred = train_util.predict_noise_xl(
-                self.sd.unet,
-                self.sd.noise_scheduler,
-                timestep,
-                latents,
-                text_embeddings.text_embeds,
-                text_embeddings.pooled_embeds,
-                add_time_ids,
-                guidance_scale=guidance_scale,
-                guidance_rescale=guidance_rescale
-            )
+            # noise_pred = train_util.predict_noise_xl(
+            #     self.sd.unet,
+            #     self.sd.noise_scheduler,
+            #     timestep,
+            #     latents,
+            #     text_embeddings.text_embeds,
+            #     text_embeddings.pooled_embeds,
+            #     add_time_ids,
+            #     guidance_scale=guidance_scale,
+            #     guidance_rescale=guidance_rescale
+            # )
+            latent_model_input = torch.cat([latents] * 2)
+
+            latent_model_input = self.sd.noise_scheduler.scale_model_input(latent_model_input, timestep)
+
+            added_cond_kwargs = {
+                "text_embeds": text_embeddings.pooled_embeds,
+                "time_ids": add_time_ids,
+            }
+
+            # predict the noise residual
+            noise_pred = self.sd.unet(
+                latent_model_input,
+                timestep,
+                encoder_hidden_states=text_embeddings.text_embeds,
+                added_cond_kwargs=added_cond_kwargs,
+            ).sample
+
+            # perform guidance
+            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+            guided_target = noise_pred_uncond + guidance_scale * (
+                noise_pred_text - noise_pred_uncond
+            )
+
+            # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
+            # noise_pred = rescale_noise_cfg(
+            #     noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
+            # )
+
+            noise_pred = guided_target
+
         else:
             noise_pred = train_util.predict_noise(
                 self.sd.unet,
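Note: the replacement code runs the UNet once on a doubled batch (unconditional and text-conditioned halves concatenated) and then applies classifier-free guidance by hand. The guidance step alone, with a dummy tensor standing in for the UNet output:

import torch

def apply_cfg(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Split the doubled batch back into its two halves, then blend:
    # uncond + scale * (text - uncond)
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

pred = torch.randn(2, 4, 64, 64)   # batch of 2 = [uncond, text]
guided = apply_cfg(pred, guidance_scale=7.5)
print(guided.shape)                # torch.Size([1, 4, 64, 64])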
@@ -366,6 +438,32 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # return latents_steps
         return latents
 
+    def get_latest_save_path(self):
+        # get latest saved step
+        if os.path.exists(self.save_root):
+            latest_file = None
+            # pattern is {job_name}_{zero_filled_step}.safetensors or {job_name}.safetensors
+            pattern = f"{self.job.name}*.safetensors"
+            files = glob.glob(os.path.join(self.save_root, pattern))
+            if len(files) > 0:
+                latest_file = max(files, key=os.path.getctime)
+            return latest_file
+        else:
+            return None
+
+    def load_weights(self, path):
+        if self.network is not None:
+            self.network.load_weights(path)
+            meta = load_metadata_from_safetensors(path)
+            # if 'training_info' in OrderedDict keys
+            if 'training_info' in meta and 'step' in meta['training_info']:
+                self.step_num = meta['training_info']['step']
+                self.start_step = self.step_num
+                print(f"Found step {self.step_num} in metadata, starting from there")
+        else:
+            print("load_weights not implemented for non-network models")
 
     def run(self):
         super().run()
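Note: load_weights recovers the step counter from the training_info entry that get_meta_for_safetensors writes into the safetensors header. A sketch of reading it back directly; the key layout follows the code above, while the path and the string-decoding fallback are assumptions:

import json
from safetensors import safe_open

def read_training_step(path: str):
    # safetensors header metadata is a flat dict of string values
    with safe_open(path, framework="pt") as f:
        meta = f.metadata() or {}
    info = meta.get("training_info")
    if info is None:
        return None
    if isinstance(info, str):  # stored as a JSON string in the header
        info = json.loads(info)
    return info.get("step")

# step = read_training_step("output/my_job/my_job_000001000.safetensors")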
@@ -407,20 +505,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
         unet.to(self.device_torch, dtype=dtype)
         if self.train_config.xformers:
             unet.enable_xformers_memory_efficient_attention()
+        if self.train_config.gradient_checkpointing:
+            unet.enable_gradient_checkpointing()
         unet.requires_grad_(False)
         unet.eval()
 
         if self.network_config is not None:
+            conv = self.network_config.conv if self.network_config.conv is not None and self.network_config.conv > 0 else None
             self.network = LoRASpecialNetwork(
                 text_encoder=text_encoder,
                 unet=unet,
-                lora_dim=self.network_config.rank,
+                lora_dim=self.network_config.linear,
                 multiplier=1.0,
                 alpha=self.network_config.alpha,
                 train_unet=self.train_config.train_unet,
                 train_text_encoder=self.train_config.train_text_encoder,
+                conv_lora_dim=conv,
+                conv_alpha=self.network_config.alpha if conv is not None else None,
             )
 
             self.network.force_to(self.device_torch, dtype=dtype)
 
             self.network.apply_to(
@@ -438,6 +542,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 default_lr=self.train_config.lr
             )
 
+            latest_save_path = self.get_latest_save_path()
+            if latest_save_path is not None:
+                self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
+                self.print(f"Loading from {latest_save_path}")
+                self.load_weights(latest_save_path)
+                self.network.multiplier = 1.0
+
         else:
             params = []
             # assume dreambooth/finetune
@@ -475,15 +588,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
             self.print("Skipping first sample due to config setting")
         else:
             self.print("Generating baseline samples before training")
-            self.sample(0)
+            self.sample(0, is_first=True)
 
         self.progress_bar = tqdm(
             total=self.train_config.steps,
             desc=self.job.name,
             leave=True
         )
-        self.step_num = 0
-        for step in range(self.train_config.steps):
+        # set it to our current step in case it was updated from a load
+        self.progress_bar.update(self.step_num)
+        # self.step_num = 0
+        for step in range(self.step_num, self.train_config.steps):
             # todo handle dataloader here maybe, not sure
 
             ### HOOK ###
jobs/process/TrainSDRescaleProcess.py (new file, 278 lines)
@@ -0,0 +1,278 @@
+# ref:
+# - https://github.com/p1atdev/LECO/blob/main/train_lora.py
+import time
+from collections import OrderedDict
+import os
+from typing import Optional
+
+from safetensors.torch import load_file, save_file
+from tqdm import tqdm
+
+from toolkit.config_modules import SliderConfig
+from toolkit.layers import ReductionKernel
+from toolkit.paths import REPOS_ROOT
+import sys
+
+from toolkit.stable_diffusion_model import PromptEmbeds
+
+sys.path.append(REPOS_ROOT)
+sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
+from toolkit.train_tools import get_torch_dtype, apply_noise_offset
+import gc
+from toolkit import train_tools
+
+import torch
+from leco import train_util, model_util
+from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion
+
+
+def flush():
+    torch.cuda.empty_cache()
+    gc.collect()
+
+
+class RescaleConfig:
+    def __init__(
+            self,
+            **kwargs
+    ):
+        self.from_resolution = kwargs.get('from_resolution', 512)
+        self.scale = kwargs.get('scale', 0.5)
+        self.prompt_file = kwargs.get('prompt_file', None)
+        self.prompt_tensors = kwargs.get('prompt_tensors', None)
+        self.to_resolution = kwargs.get('to_resolution', int(self.from_resolution * self.scale))
+
+        if self.prompt_file is None:
+            raise ValueError("prompt_file is required")
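Note: with the defaults above, from_resolution 512 and scale 0.5 yield to_resolution 256, so the ReductionKernel built in TrainSDRescaleProcess.__init__ below gets kernel_size 512 // 256 = 2. An illustrative 'rescale' config as consumed by RescaleConfig; the values are examples, and only prompt_file is actually required:

rescale_conf = {
    'from_resolution': 512,  # teacher/full resolution
    'scale': 0.5,            # to_resolution defaults to int(512 * 0.5) = 256
    'prompt_file': 'prompts.txt',                   # one prompt per line
    'prompt_tensors': 'prompt_cache.safetensors',   # optional embed cache
}

config = RescaleConfig(**rescale_conf)
assert config.to_resolution == 256
# kernel_size passed to ReductionKernel: 512 // 256 == 2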
+class PromptEmbedsCache:
+    prompts: dict[str, PromptEmbeds] = {}
+
+    def __setitem__(self, __name: str, __value: PromptEmbeds) -> None:
+        self.prompts[__name] = __value
+
+    def __getitem__(self, __name: str) -> Optional[PromptEmbeds]:
+        if __name in self.prompts:
+            return self.prompts[__name]
+        else:
+            return None
+
+
+class TrainSDRescaleProcess(BaseSDTrainProcess):
+    def __init__(self, process_id: int, job, config: OrderedDict):
+        super().__init__(process_id, job, config)
+        self.step_num = 0
+        self.start_step = 0
+        self.device = self.get_conf('device', self.job.device)
+        self.device_torch = torch.device(self.device)
+        self.prompt_cache = PromptEmbedsCache()
+        self.rescale_config = RescaleConfig(**self.get_conf('rescale', required=True))
+        self.reduce_size_fn = ReductionKernel(
+            in_channels=4,
+            kernel_size=int(self.rescale_config.from_resolution // self.rescale_config.to_resolution),
+            dtype=get_torch_dtype(self.train_config.dtype),
+            device=self.device_torch,
+        )
+        self.prompt_txt_list = []
+
+    def before_model_load(self):
+        pass
+
+    def hook_before_train_loop(self):
+        self.print(f"Loading prompt file from {self.rescale_config.prompt_file}")
+
+        # read line by line from file
+        with open(self.rescale_config.prompt_file, 'r') as f:
+            self.prompt_txt_list = f.readlines()
+        # clean empty lines
+        self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]
+
+        self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..")
+
+        cache = PromptEmbedsCache()
+
+        # get encoded latents for our prompts
+        with torch.no_grad():
+            if self.rescale_config.prompt_tensors is not None:
+                # check to see if it exists
+                if os.path.exists(self.rescale_config.prompt_tensors):
+                    # load it.
+                    self.print(f"Loading prompt tensors from {self.rescale_config.prompt_tensors}")
+                    prompt_tensors = load_file(self.rescale_config.prompt_tensors, device='cpu')
+                    # add them to the cache
+                    for prompt_txt, prompt_tensor in prompt_tensors.items():
+                        if prompt_txt.startswith("te:"):
+                            prompt = prompt_txt[3:]
+                            # text_embeds
+                            text_embeds = prompt_tensor
+                            pooled_embeds = None
+                            # find pooled embeds
+                            if f"pe:{prompt}" in prompt_tensors:
+                                pooled_embeds = prompt_tensors[f"pe:{prompt}"]
+
+                            # make it
+                            prompt_embeds = PromptEmbeds([text_embeds, pooled_embeds])
+                            cache[prompt] = prompt_embeds.to(device='cpu', dtype=torch.float32)
+
+            if len(cache.prompts) == 0:
+                print("Prompt tensors not found. Encoding prompts..")
+                neutral = ""
+                # encode neutral
+                cache[neutral] = self.sd.encode_prompt(neutral)
+                for prompt in tqdm(self.prompt_txt_list, desc="Encoding prompts", leave=False):
+                    # build the cache
+                    if cache[prompt] is None:
+                        cache[prompt] = self.sd.encode_prompt(prompt).to(device="cpu", dtype=torch.float32)
+
+                if self.rescale_config.prompt_tensors:
+                    print(f"Saving prompt tensors to {self.rescale_config.prompt_tensors}")
+                    state_dict = {}
+                    for prompt_txt, prompt_embeds in cache.prompts.items():
+                        state_dict[f"te:{prompt_txt}"] = prompt_embeds.text_embeds.to("cpu", dtype=get_torch_dtype('fp16'))
+                        if prompt_embeds.pooled_embeds is not None:
+                            state_dict[f"pe:{prompt_txt}"] = prompt_embeds.pooled_embeds.to("cpu", dtype=get_torch_dtype('fp16'))
+                    save_file(state_dict, self.rescale_config.prompt_tensors)
+
+        self.print("Encoding complete.")
+
+        # move to cpu to save vram
+        # We don't need the text encoder anymore, but keep it on cpu for sampling
+        # if the text encoder is a list
+        if isinstance(self.sd.text_encoder, list):
+            for encoder in self.sd.text_encoder:
+                encoder.to("cpu")
+        else:
+            self.sd.text_encoder.to("cpu")
+        self.prompt_cache = cache
+
+        flush()
+        # end hook_before_train_loop
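Note: the prompt tensor cache written above stores flat tensors keyed by prefix: "te:<prompt>" for text embeddings and "pe:<prompt>" for pooled embeddings (only present for SDXL-style models). A round-trip sketch of that layout using only safetensors; shapes and the file name are illustrative:

import torch
from safetensors.torch import load_file, save_file

state_dict = {
    "te:a photo of a cat": torch.randn(77, 768, dtype=torch.float16),
    # "pe:a photo of a cat": torch.randn(1280, dtype=torch.float16),  # SDXL only
}
save_file(state_dict, "prompt_cache.safetensors")

loaded = load_file("prompt_cache.safetensors", device="cpu")
for key, tensor in loaded.items():
    if key.startswith("te:"):
        prompt = key[3:]
        pooled = loaded.get(f"pe:{prompt}")  # None for non-XL models
        print(prompt, tensor.shape, None if pooled is None else pooled.shape)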
+    def hook_train_loop(self):
+        dtype = get_torch_dtype(self.train_config.dtype)
+
+        # get random encoded prompt from cache
+        prompt_txt = self.prompt_txt_list[
+            torch.randint(0, len(self.prompt_txt_list), (1,)).item()
+        ]
+        prompt = self.prompt_cache[prompt_txt].to(device=self.device_torch, dtype=dtype)
+        neutral = self.prompt_cache[""].to(device=self.device_torch, dtype=dtype)
+        if prompt is None:
+            raise ValueError(f"Prompt {prompt_txt} is not in cache")
+
+        prompt_batch = train_tools.concat_prompt_embeddings(
+            prompt,
+            neutral,
+            self.train_config.batch_size,
+        )
+
+        noise_scheduler = self.sd.noise_scheduler
+        optimizer = self.optimizer
+        lr_scheduler = self.lr_scheduler
+        loss_function = torch.nn.MSELoss()
+
+        def get_noise_pred(p, n, gs, cts, dn):
+            return self.predict_noise(
+                latents=dn,
+                text_embeddings=train_tools.concat_prompt_embeddings(
+                    p,  # unconditional
+                    n,  # positive
+                    self.train_config.batch_size,
+                ),
+                timestep=cts,
+                guidance_scale=gs,
+            )
+
+        with torch.no_grad():
+            self.sd.noise_scheduler.set_timesteps(
+                self.train_config.max_denoising_steps, device=self.device_torch
+            )
+
+            self.optimizer.zero_grad()
+
+            # # get a random number of steps
+            timesteps_to = torch.randint(
+                1, self.train_config.max_denoising_steps, (1,)
+            ).item()
+
+            # get noise
+            noise = self.get_latent_noise(
+                pixel_height=self.rescale_config.from_resolution,
+                pixel_width=self.rescale_config.from_resolution,
+            ).to(self.device_torch, dtype=dtype)
+
+            # get latents
+            latents = noise * self.sd.noise_scheduler.init_noise_sigma
+            latents = latents.to(self.device_torch, dtype=dtype)
+            #
+            # # predict without network
+            # assert self.network.is_active is False
+            # denoised_latents = self.diffuse_some_steps(
+            #     latents,  # pass simple noise latents
+            #     prompt_batch,
+            #     start_timesteps=0,
+            #     total_timesteps=timesteps_to,
+            #     guidance_scale=3,
+            # )
+            # noise_scheduler.set_timesteps(1000)
+            #
+            # current_timestep = noise_scheduler.timesteps[
+            #     int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
+            # ]
+
+            current_timestep = 0
+            denoised_latents = latents
+            # get noise prediction at full scale
+            from_prediction = get_noise_pred(
+                prompt, neutral, 1, current_timestep, denoised_latents
+            )
+
+            reduced_from_prediction = self.reduce_size_fn(from_prediction).to("cpu", dtype=torch.float32)
+
+            # get noise prediction at reduced scale
+            to_denoised_latents = self.reduce_size_fn(denoised_latents)
+
+        # start gradient
+        optimizer.zero_grad()
+        self.network.multiplier = 1.0
+        with self.network:
+            assert self.network.is_active is True
+            to_prediction = get_noise_pred(
+                prompt, neutral, 1, current_timestep, to_denoised_latents
+            ).to("cpu", dtype=torch.float32)
+
+        reduced_from_prediction.requires_grad = False
+        from_prediction.requires_grad = False
+
+        loss = loss_function(
+            reduced_from_prediction,
+            to_prediction,
+        )
+
+        loss_float = loss.item()
+
+        loss = loss.to(self.device_torch)
+
+        loss.backward()
+        optimizer.step()
+        lr_scheduler.step()
+
+        del (
+            reduced_from_prediction,
+            from_prediction,
+            to_denoised_latents,
+            to_prediction,
+            latents,
+        )
+        flush()
+
+        # reset network
+        self.network.multiplier = 1.0
+
+        loss_dict = OrderedDict(
+            {'loss': loss_float},
+        )
+
+        return loss_dict
+        # end hook_train_loop
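Note: the loop above distills scale consistency into the LoRA: predict noise at from_resolution with the network inactive, reduce that prediction, and train the network so its prediction on the reduced latents matches. The objective stripped of the model plumbing; the tensors are stand-ins, and avg_pool2d stands in for ReductionKernel, which computes the same per-channel average reduction:

import torch
import torch.nn.functional as F

k = 2                                        # from_resolution // to_resolution
from_prediction = torch.randn(1, 4, 64, 64)  # no-network prediction (frozen)
target = F.avg_pool2d(from_prediction, k).detach()   # reduce, stop gradient
to_prediction = torch.randn(1, 4, 32, 32, requires_grad=True)  # network active

loss = F.mse_loss(to_prediction, target)
loss.backward()  # gradients flow only into the network-side prediction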
@@ -669,7 +669,7 @@ class TrainVAEProcess(BaseTrainProcess):
         if self.writer is not None:
             # get avg loss
             for key in log_losses:
-                log_losses[key] = sum(log_losses[key]) / len(log_losses[key])
+                log_losses[key] = sum(log_losses[key]) / (len(log_losses[key]) + 1e-6)
                 # if log_losses[key] > 0:
                 self.writer.add_scalar(f"loss/{key}", log_losses[key], self.step_num)
         # reset log losses
@@ -678,9 +678,10 @@ class TrainVAEProcess(BaseTrainProcess):
             self.step_num += 1
         # end epoch
         if self.writer is not None:
+            eps = 1e-6
             # get avg loss
             for key in epoch_losses:
-                epoch_losses[key] = sum(log_losses[key]) / len(log_losses[key])
+                epoch_losses[key] = sum(log_losses[key]) / (len(log_losses[key]) + eps)
                 if epoch_losses[key] > 0:
                     self.writer.add_scalar(f"epoch loss/{key}", epoch_losses[key], epoch)
         # reset epoch losses
@@ -7,3 +7,4 @@ from .TrainVAEProcess import TrainVAEProcess
 from .BaseMergeProcess import BaseMergeProcess
 from .TrainSliderProcess import TrainSliderProcess
 from .TrainLoRAHack import TrainLoRAHack
+from .TrainSDRescaleProcess import TrainSDRescaleProcess
@@ -10,4 +10,6 @@ pyyaml
 oyaml
 tensorboard
 kornia
 invisible-watermark
+einops
+accelerate
@@ -5,6 +5,7 @@ class SaveConfig:
     def __init__(self, **kwargs):
         self.save_every: int = kwargs.get('save_every', 1000)
         self.dtype: str = kwargs.get('save_dtype', 'float16')
+        self.max_step_saves_to_keep: int = kwargs.get('max_step_saves_to_keep', 5)
 
 
 class LogingConfig:
@@ -30,8 +31,16 @@ class SampleConfig:
 
 class NetworkConfig:
     def __init__(self, **kwargs):
-        self.type: str = kwargs.get('type', 'lierla')
-        self.rank: int = kwargs.get('rank', 4)
+        self.type: str = kwargs.get('type', 'lora')
+        rank = kwargs.get('rank', None)
+        linear = kwargs.get('linear', None)
+        if rank is not None:
+            self.rank: int = rank  # rank for backward compatibility
+            self.linear: int = rank
+        elif linear is not None:
+            self.rank: int = linear
+            self.linear: int = linear
         self.conv: int = kwargs.get('conv', None)
         self.alpha: float = kwargs.get('alpha', 1.0)
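Note: either spelling now populates both attributes, so older configs using rank keep working while linear matches the new conv key used for LoCon. If neither key is given, neither attribute is set, so a config must supply one. Illustrative values:

old_style = NetworkConfig(type='lora', rank=8, alpha=4.0)
new_style = NetworkConfig(type='lora', linear=8, conv=4, alpha=4.0)
assert old_style.rank == old_style.linear == 8
assert new_style.rank == new_style.linear == 8
assert new_style.conv == 4  # enables conv (LoCon) layers downstream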
@@ -51,6 +60,7 @@ class TrainConfig:
         self.noise_offset = kwargs.get('noise_offset', 0.0)
         self.optimizer_params = kwargs.get('optimizer_params', {})
         self.skip_first_sample = kwargs.get('skip_first_sample', False)
+        self.gradient_checkpointing = kwargs.get('gradient_checkpointing', False)
 
 
 class ModelConfig:
toolkit/layers.py (new file, 31 lines)
@@ -0,0 +1,31 @@
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+class ReductionKernel(nn.Module):
+    # TensorFlow-style reduction (average-pooling) kernel applied as a strided conv
+    def __init__(self, in_channels, kernel_size=2, dtype=torch.float32, device=None):
+        if device is None:
+            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        super(ReductionKernel, self).__init__()
+        self.kernel_size = kernel_size
+        self.in_channels = in_channels
+        numpy_kernel = self.build_kernel()
+        self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
+
+    def build_kernel(self):
+        # tensorflow kernel is (height, width, in_channels, out_channels)
+        # pytorch kernel is (out_channels, in_channels, height, width)
+        kernel_size = self.kernel_size
+        channels = self.in_channels
+        kernel_shape = [channels, channels, kernel_size, kernel_size]
+        kernel = np.zeros(kernel_shape, np.float32)
+
+        kernel_value = 1.0 / (kernel_size * kernel_size)
+        for i in range(0, channels):
+            kernel[i, i, :, :] = kernel_value
+        return kernel
+
+    def forward(self, x):
+        return nn.functional.conv2d(x, self.kernel, stride=self.kernel_size, padding=0, groups=1)
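Note: the kernel puts 1/(k*k) on the channel diagonal (output channel i reads only input channel i) and strides by k, which makes the layer equivalent to k x k average pooling per channel, presumably written as an explicit conv for dtype/device control. A quick check:

import torch
import torch.nn.functional as F

k, channels = 2, 4
x = torch.randn(1, channels, 8, 8)

weight = torch.zeros(channels, channels, k, k)
for i in range(channels):
    weight[i, i] = 1.0 / (k * k)  # same construction as build_kernel above

out_conv = F.conv2d(x, weight, stride=k, padding=0)
out_pool = F.avg_pool2d(x, kernel_size=k)
print(torch.allclose(out_conv, out_pool, atol=1e-6))  # True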
@@ -1,5 +1,8 @@
 import json
 from collections import OrderedDict
+
+from safetensors import safe_open
+
 from info import software_meta
@@ -25,4 +28,10 @@ def parse_metadata_from_safetensors(meta: OrderedDict) -> OrderedDict:
             parsed_meta[key] = json.loads(value)
         except json.decoder.JSONDecodeError:
             parsed_meta[key] = value
-    return meta
+    return parsed_meta
+
+
+def load_metadata_from_safetensors(file_path: str) -> OrderedDict:
+    with safe_open(file_path, framework="pt") as f:
+        metadata = f.metadata()
+    return parse_metadata_from_safetensors(metadata)
@@ -1,11 +1,18 @@
-from typing import Union
+from typing import Union, OrderedDict
 import sys
 import os
 
+from safetensors.torch import save_file
+
 from toolkit.paths import REPOS_ROOT
+from toolkit.train_tools import get_torch_dtype
 
 sys.path.append(REPOS_ROOT)
 sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
 from leco import train_util
 import torch
 from library import model_util
+from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl
 
 
 class PromptEmbeds:
@@ -22,6 +29,12 @@ class PromptEmbeds:
             self.text_embeds = args
             self.pooled_embeds = None
 
+    def to(self, **kwargs):
+        self.text_embeds = self.text_embeds.to(**kwargs)
+        if self.pooled_embeds is not None:
+            self.pooled_embeds = self.pooled_embeds.to(**kwargs)
+        return self
+
 
 class StableDiffusion:
     def __init__(
@@ -61,3 +74,41 @@ class StableDiffusion:
                 self.tokenizer, self.text_encoder, prompt
             )
         )
+
+    def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
+        # todo see what logit scale is
+        if self.is_xl:
+            state_dict = {}
+
+            def update_sd(prefix, sd):
+                for k, v in sd.items():
+                    key = prefix + k
+                    v = v.detach().clone().to("cpu").to(get_torch_dtype(save_dtype))
+                    state_dict[key] = v
+
+            # Convert the UNet model
+            update_sd("model.diffusion_model.", self.unet.state_dict())
+
+            # Convert the text encoders
+            update_sd("conditioner.embedders.0.transformer.", self.text_encoder[0].state_dict())
+
+            text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(self.text_encoder[1].state_dict(), logit_scale)
+            update_sd("conditioner.embedders.1.model.", text_enc2_dict)
+
+            # Convert the VAE
+            vae_dict = model_util.convert_vae_state_dict(self.vae.state_dict())
+            update_sd("first_stage_model.", vae_dict)
+
+            # Put together new checkpoint
+            key_count = len(state_dict.keys())
+            new_ckpt = {"state_dict": state_dict}
+
+            if model_util.is_safetensors(output_file):
+                save_file(state_dict, output_file)
+            else:
+                torch.save(new_ckpt, output_file, meta)
+
+            return key_count
+        else:
+            raise NotImplementedError("sdv1.x, sdv2.x is not implemented yet")
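Note: save() remaps diffusers-style module state dicts onto original LDM checkpoint keys by prefixing. The prefix mapping restated in isolation with a toy state dict; the table constant and helper signature are restatements of the code above, not new API:

import torch
from safetensors.torch import save_file

PREFIXES = {
    "unet": "model.diffusion_model.",
    "text_encoder_0": "conditioner.embedders.0.transformer.",
    "text_encoder_1": "conditioner.embedders.1.model.",
    "vae": "first_stage_model.",
}

def update_sd(state_dict, prefix, sd, dtype=torch.float16):
    # same shape as the inner helper above: clone to CPU, cast, re-key
    for k, v in sd.items():
        state_dict[prefix + k] = v.detach().clone().to("cpu", dtype)

state_dict = {}
update_sd(state_dict, PREFIXES["vae"], {"encoder.w": torch.randn(2, 2)})
save_file(state_dict, "toy_checkpoint.safetensors")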
@@ -2,6 +2,7 @@ import argparse
 import json
 import os
 import time
+from typing import TYPE_CHECKING
 
 from diffusers import (
     StableDiffusionPipeline,
@@ -21,8 +22,6 @@ from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipel
 import torch
 import re
 
-from toolkit.stable_diffusion_model import PromptEmbeds
-
 SCHEDULER_LINEAR_START = 0.00085
 SCHEDULER_LINEAR_END = 0.0120
 SCHEDULER_TIMESTEPS = 1000
@@ -381,11 +380,16 @@ def apply_noise_offset(noise, noise_offset):
     return noise
 
 
+if TYPE_CHECKING:
+    from toolkit.stable_diffusion_model import PromptEmbeds
+
+
 def concat_prompt_embeddings(
-    unconditional: PromptEmbeds,
-    conditional: PromptEmbeds,
+    unconditional: 'PromptEmbeds',
+    conditional: 'PromptEmbeds',
     n_imgs: int,
 ):
+    from toolkit.stable_diffusion_model import PromptEmbeds
     text_embeds = torch.cat(
         [unconditional.text_embeds, conditional.text_embeds]
    ).repeat_interleave(n_imgs, dim=0)
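Note: this change breaks a circular import: toolkit.stable_diffusion_model imports from toolkit.train_tools, so train_tools can only import PromptEmbeds for type checking and must use string annotations, deferring the real import to call time. The pattern in general form; the function name here is illustrative:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # evaluated only by type checkers, never at runtime, so no import cycle
    from toolkit.stable_diffusion_model import PromptEmbeds


def concat(unconditional: 'PromptEmbeds', conditional: 'PromptEmbeds'):
    # deferred import: by call time both modules are fully initialized
    from toolkit.stable_diffusion_model import PromptEmbeds
    assert isinstance(unconditional, PromptEmbeds)
    ...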