Added rescaling, locon, sdxl, all kinds of stuff. sdxl is still weird

Jaret Burkett
2023-07-26 16:19:50 -06:00
parent 40e60fa021
commit d3ad195b51
11 changed files with 548 additions and 45 deletions

View File

@@ -18,6 +18,7 @@ process_dict = {
     'vae': 'TrainVAEProcess',
     'slider': 'TrainSliderProcess',
     'lora_hack': 'TrainLoRAHack',
+    'rescale_sd': 'TrainSDRescaleProcess',
 }
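Note: a string registry like this is typically resolved by name when a job config is loaded. A minimal sketch of that lookup, assuming the classes are exported from jobs.process (the actual loader in this repo may differ):

import importlib

def get_process_class(type_name: str):
    # e.g. 'rescale_sd' -> 'TrainSDRescaleProcess'
    class_name = process_dict[type_name]
    return getattr(importlib.import_module('jobs.process'), class_name)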

View File

@@ -1,7 +1,10 @@
+import glob
 import time
 from collections import OrderedDict
 import os
+
+from safetensors import safe_open
 from toolkit.kohya_model_util import load_vae
 from toolkit.lora_special import LoRASpecialNetwork
 from toolkit.optimizer import get_optimizer
@@ -14,7 +17,7 @@ sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
 from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
 from jobs.process import BaseTrainProcess
-from toolkit.metadata import get_meta_for_safetensors
+from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors
 from toolkit.train_tools import get_torch_dtype, apply_noise_offset
 import gc
@@ -48,6 +51,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.model_config = ModelConfig(**self.get_conf('model', {}))
         self.save_config = SaveConfig(**self.get_conf('save', {}))
         self.sample_config = SampleConfig(**self.get_conf('sample', {}))
+        self.first_sample_config = SampleConfig(**self.get_conf('first_sample', {})) if 'first_sample' in self.config else self.sample_config
         self.logging_config = LogingConfig(**self.get_conf('logging', {}))
         self.optimizer = None
         self.lr_scheduler = None
@@ -56,7 +60,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # added later
         self.network = None

-    def sample(self, step=None):
+    def sample(self, step=None, is_first=False):
         sample_folder = os.path.join(self.save_root, 'samples')
         if not os.path.exists(sample_folder):
             os.makedirs(sample_folder, exist_ok=True)
@@ -112,7 +116,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # disable progress bar
         pipeline.set_progress_bar_config(disable=True)

+        sample_config = self.first_sample_config if is_first else self.sample_config
+
-        start_seed = self.sample_config.seed
+        start_seed = sample_config.seed
         start_multiplier = self.network.multiplier
         current_seed = start_seed
@@ -127,14 +133,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
             'multiplier': self.network.multiplier,
         })

-        for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}",
+        for i in tqdm(range(len(sample_config.prompts)), desc=f"Generating Samples - step: {step}",
                       leave=False):
-            raw_prompt = self.sample_config.prompts[i]
-            neg = self.sample_config.neg
-            multiplier = self.sample_config.network_multiplier
+            raw_prompt = sample_config.prompts[i]
+            neg = sample_config.neg
+            multiplier = sample_config.network_multiplier
             p_split = raw_prompt.split('--')
             prompt = p_split[0].strip()
+            height = sample_config.height
+            width = sample_config.width

             if len(p_split) > 1:
                 for split in p_split:
@@ -145,13 +153,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     elif flag == 'm':
                         # multiplier
                         multiplier = float(content)
+                    elif flag == 'w':
+                        # width
+                        width = int(content)
+                    elif flag == 'h':
+                        # height
+                        height = int(content)

-            height = self.sample_config.height
-            width = self.sample_config.width
             height = max(64, height - height % 8)  # round to divisible by 8
             width = max(64, width - width % 8)  # round to divisible by 8

-            if self.sample_config.walk_seed:
+            if sample_config.walk_seed:
                 current_seed += i

             if self.network is not None:
@@ -159,14 +171,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 torch.manual_seed(current_seed)
                 torch.cuda.manual_seed(current_seed)

-                img = pipeline(
-                    prompt,
-                    height=height,
-                    width=width,
-                    num_inference_steps=self.sample_config.sample_steps,
-                    guidance_scale=self.sample_config.guidance_scale,
-                    negative_prompt=neg,
-                ).images[0]
+                if self.sd.is_xl:
+                    img = pipeline(
+                        prompt,
+                        height=height,
+                        width=width,
+                        num_inference_steps=sample_config.sample_steps,
+                        guidance_scale=sample_config.guidance_scale,
+                        negative_prompt=neg,
+                    ).images[0]
+                else:
+                    img = pipeline(
+                        prompt,
+                        height=height,
+                        width=width,
+                        num_inference_steps=sample_config.sample_steps,
+                        guidance_scale=sample_config.guidance_scale,
+                        negative_prompt=neg,
+                    ).images[0]

                 step_num = ''
                 if step is not None:
@@ -209,6 +231,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
         })
         return info

+    def clean_up_saves(self):
+        # remove old saves
+        # get latest saved step
+        if os.path.exists(self.save_root):
+            latest_file = None
+            # pattern is {job_name}_{zero_filled_step}.safetensors but NOT {job_name}.safetensors
+            pattern = f"{self.job.name}_*.safetensors"
+            files = glob.glob(os.path.join(self.save_root, pattern))
+            if len(files) > self.save_config.max_step_saves_to_keep:
+                # remove all but the latest max_step_saves_to_keep
+                files.sort(key=os.path.getctime)
+                for file in files[:-self.save_config.max_step_saves_to_keep]:
+                    self.print(f"Removing old save: {file}")
+                    os.remove(file)
+            return latest_file
+        else:
+            return None
+
     def save(self, step=None):
         if not os.path.exists(self.save_root):
             os.makedirs(self.save_root, exist_ok=True)
@@ -231,9 +271,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 metadata=save_meta
             )
         else:
-            # TODO handle dreambooth, fine tuning, etc
-            # will probably have to convert dict back to LDM
-            ValueError("Non network training is not currently supported")
+            self.sd.save(
+                file_path,
+                save_meta,
+                get_torch_dtype(self.save_config.dtype)
+            )

         self.print(f"Saved to {file_path}")
@@ -258,6 +300,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
     ):
         if height is None and pixel_height is None:
             raise ValueError("height or pixel_height must be specified")
+            raise ValueError("height or pixel_height must be specified")
         if width is None and pixel_width is None:
             raise ValueError("width or pixel_width must be specified")
         if height is None:
@@ -316,18 +359,47 @@ class BaseSDTrainProcess(BaseTrainProcess):
             if add_time_ids is None:
                 add_time_ids = self.get_time_ids_from_latents(latents)
             # todo LECOs code looks like it is omitting noise_pred
-            noise_pred = train_util.predict_noise_xl(
-                self.sd.unet,
-                self.sd.noise_scheduler,
-                timestep,
-                latents,
-                text_embeddings.text_embeds,
-                text_embeddings.pooled_embeds,
-                add_time_ids,
-                guidance_scale=guidance_scale,
-                guidance_rescale=guidance_rescale
-            )
+            # noise_pred = train_util.predict_noise_xl(
+            #     self.sd.unet,
+            #     self.sd.noise_scheduler,
+            #     timestep,
+            #     latents,
+            #     text_embeddings.text_embeds,
+            #     text_embeddings.pooled_embeds,
+            #     add_time_ids,
+            #     guidance_scale=guidance_scale,
+            #     guidance_rescale=guidance_rescale
+            # )
+            latent_model_input = torch.cat([latents] * 2)
+            latent_model_input = self.sd.noise_scheduler.scale_model_input(latent_model_input, timestep)
+
+            added_cond_kwargs = {
+                "text_embeds": text_embeddings.pooled_embeds,
+                "time_ids": add_time_ids,
+            }
+
+            # predict the noise residual
+            noise_pred = self.sd.unet(
+                latent_model_input,
+                timestep,
+                encoder_hidden_states=text_embeddings.text_embeds,
+                added_cond_kwargs=added_cond_kwargs,
+            ).sample
+
+            # perform guidance
+            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+            guided_target = noise_pred_uncond + guidance_scale * (
+                noise_pred_text - noise_pred_uncond
+            )
+            # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
+            # noise_pred = rescale_noise_cfg(
+            #     noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
+            # )
+            noise_pred = guided_target
         else:
             noise_pred = train_util.predict_noise(
                 self.sd.unet,
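The manual SDXL branch above is plain classifier-free guidance: one batched UNet call over the [unconditional, conditional] pair, then a weighted combine. A minimal self-contained sketch of the combine step (shapes illustrative, not taken from this diff):

import torch

def cfg_combine(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # noise_pred stacks the unconditional and conditional predictions along
    # the batch dimension, matching torch.cat([latents] * 2) above
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# e.g. a stacked 2x4x64x64 prediction at guidance scale 7.5
combined = cfg_combine(torch.randn(2, 4, 64, 64), guidance_scale=7.5)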
@@ -366,6 +438,32 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # return latents_steps
         return latents

+    def get_latest_save_path(self):
+        # get latest saved step
+        if os.path.exists(self.save_root):
+            latest_file = None
+            # pattern is {job_name}_{zero_filled_step}.safetensors or {job_name}.safetensors
+            pattern = f"{self.job.name}*.safetensors"
+            files = glob.glob(os.path.join(self.save_root, pattern))
+            if len(files) > 0:
+                latest_file = max(files, key=os.path.getctime)
+            return latest_file
+        else:
+            return None
+
+    def load_weights(self, path):
+        if self.network is not None:
+            self.network.load_weights(path)
+            meta = load_metadata_from_safetensors(path)
+            # if 'training_info' is in the OrderedDict keys
+            if 'training_info' in meta and 'step' in meta['training_info']:
+                self.step_num = meta['training_info']['step']
+                self.start_step = self.step_num
+                print(f"Found step {self.step_num} in metadata, starting from there")
+        else:
+            print("load_weights not implemented for non-network models")
+
     def run(self):
         super().run()
@@ -407,20 +505,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
         unet.to(self.device_torch, dtype=dtype)
         if self.train_config.xformers:
             unet.enable_xformers_memory_efficient_attention()
+        if self.train_config.gradient_checkpointing:
+            unet.enable_gradient_checkpointing()
         unet.requires_grad_(False)
         unet.eval()

         if self.network_config is not None:
+            conv = self.network_config.conv if self.network_config.conv is not None and self.network_config.conv > 0 else None
             self.network = LoRASpecialNetwork(
                 text_encoder=text_encoder,
                 unet=unet,
-                lora_dim=self.network_config.rank,
+                lora_dim=self.network_config.linear,
                 multiplier=1.0,
                 alpha=self.network_config.alpha,
                 train_unet=self.train_config.train_unet,
                 train_text_encoder=self.train_config.train_text_encoder,
+                conv_lora_dim=conv,
+                conv_alpha=self.network_config.alpha if conv is not None else None,
             )

             self.network.force_to(self.device_torch, dtype=dtype)
             self.network.apply_to(
@@ -438,6 +542,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 default_lr=self.train_config.lr
             )

+            latest_save_path = self.get_latest_save_path()
+
+            if latest_save_path is not None:
+                self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
+                self.print(f"Loading from {latest_save_path}")
+                self.load_weights(latest_save_path)
+                self.network.multiplier = 1.0
+
         else:
             params = []
             # assume dreambooth/finetune
@@ -475,15 +588,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
             self.print("Skipping first sample due to config setting")
         else:
             self.print("Generating baseline samples before training")
-            self.sample(0)
+            self.sample(0, is_first=True)

         self.progress_bar = tqdm(
             total=self.train_config.steps,
             desc=self.job.name,
             leave=True
         )
-        self.step_num = 0
-        for step in range(self.train_config.steps):
+        # set it to our current step in case it was updated from a load
+        self.progress_bar.update(self.step_num)
+        # self.step_num = 0
+        for step in range(self.step_num, self.train_config.steps):
             # todo handle dataloader here maybe, not sure
             ### HOOK ###

View File

@@ -0,0 +1,278 @@
# ref:
# - https://github.com/p1atdev/LECO/blob/main/train_lora.py
import time
from collections import OrderedDict
import os
from typing import Optional

from safetensors.torch import load_file, save_file
from tqdm import tqdm

from toolkit.config_modules import SliderConfig
from toolkit.layers import ReductionKernel
from toolkit.paths import REPOS_ROOT
import sys

from toolkit.stable_diffusion_model import PromptEmbeds

sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))

from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import gc
from toolkit import train_tools

import torch
from leco import train_util, model_util

from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion


def flush():
    torch.cuda.empty_cache()
    gc.collect()


class RescaleConfig:
    def __init__(
            self,
            **kwargs
    ):
        self.from_resolution = kwargs.get('from_resolution', 512)
        self.scale = kwargs.get('scale', 0.5)
        self.prompt_file = kwargs.get('prompt_file', None)
        self.prompt_tensors = kwargs.get('prompt_tensors', None)
        self.to_resolution = kwargs.get('to_resolution', int(self.from_resolution * self.scale))
        if self.prompt_file is None:
            raise ValueError("prompt_file is required")


class PromptEmbedsCache:
    prompts: dict[str, PromptEmbeds] = {}

    def __setitem__(self, __name: str, __value: PromptEmbeds) -> None:
        self.prompts[__name] = __value

    def __getitem__(self, __name: str) -> Optional[PromptEmbeds]:
        if __name in self.prompts:
            return self.prompts[__name]
        else:
            return None


class TrainSDRescaleProcess(BaseSDTrainProcess):
    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        self.step_num = 0
        self.start_step = 0
        self.device = self.get_conf('device', self.job.device)
        self.device_torch = torch.device(self.device)
        self.prompt_cache = PromptEmbedsCache()
        self.rescale_config = RescaleConfig(**self.get_conf('rescale', required=True))
        self.reduce_size_fn = ReductionKernel(
            in_channels=4,
            kernel_size=int(self.rescale_config.from_resolution // self.rescale_config.to_resolution),
            dtype=get_torch_dtype(self.train_config.dtype),
            device=self.device_torch,
        )
        self.prompt_txt_list = []

    def before_model_load(self):
        pass
    def hook_before_train_loop(self):
        self.print(f"Loading prompt file from {self.rescale_config.prompt_file}")

        # read line by line from file
        with open(self.rescale_config.prompt_file, 'r') as f:
            self.prompt_txt_list = f.readlines()
            # clean empty lines
            self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]

        self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..")

        cache = PromptEmbedsCache()

        # get encoded latents for our prompts
        with torch.no_grad():
            if self.rescale_config.prompt_tensors is not None:
                # check to see if it exists
                if os.path.exists(self.rescale_config.prompt_tensors):
                    # load it
                    self.print(f"Loading prompt tensors from {self.rescale_config.prompt_tensors}")
                    prompt_tensors = load_file(self.rescale_config.prompt_tensors, device='cpu')
                    # add them to the cache
                    for prompt_txt, prompt_tensor in prompt_tensors.items():
                        if prompt_txt.startswith("te:"):
                            prompt = prompt_txt[3:]
                            # text_embeds
                            text_embeds = prompt_tensor
                            pooled_embeds = None
                            # find pooled embeds
                            if f"pe:{prompt}" in prompt_tensors:
                                pooled_embeds = prompt_tensors[f"pe:{prompt}"]
                            # make it
                            prompt_embeds = PromptEmbeds([text_embeds, pooled_embeds])
                            cache[prompt] = prompt_embeds.to(device='cpu', dtype=torch.float32)

            if len(cache.prompts) == 0:
                print("Prompt tensors not found. Encoding prompts..")
                neutral = ""
                # encode neutral
                cache[neutral] = self.sd.encode_prompt(neutral)
                for prompt in tqdm(self.prompt_txt_list, desc="Encoding prompts", leave=False):
                    # build the cache
                    if cache[prompt] is None:
                        cache[prompt] = self.sd.encode_prompt(prompt).to(device="cpu", dtype=torch.float32)
                if self.rescale_config.prompt_tensors:
                    print(f"Saving prompt tensors to {self.rescale_config.prompt_tensors}")
                    state_dict = {}
                    for prompt_txt, prompt_embeds in cache.prompts.items():
                        state_dict[f"te:{prompt_txt}"] = prompt_embeds.text_embeds.to("cpu", dtype=get_torch_dtype('fp16'))
                        if prompt_embeds.pooled_embeds is not None:
                            state_dict[f"pe:{prompt_txt}"] = prompt_embeds.pooled_embeds.to("cpu", dtype=get_torch_dtype('fp16'))
                    save_file(state_dict, self.rescale_config.prompt_tensors)

            self.print("Encoding complete.")

        # move to cpu to save vram
        # we don't need the text encoder anymore, but keep it on cpu for sampling
        # if text encoder is a list (e.g. SDXL has two)
        if isinstance(self.sd.text_encoder, list):
            for encoder in self.sd.text_encoder:
                encoder.to("cpu")
        else:
            self.sd.text_encoder.to("cpu")

        self.prompt_cache = cache

        flush()
        # end hook_before_train_loop
    def hook_train_loop(self):
        dtype = get_torch_dtype(self.train_config.dtype)

        # get random encoded prompt from cache
        prompt_txt = self.prompt_txt_list[
            torch.randint(0, len(self.prompt_txt_list), (1,)).item()
        ]
        prompt = self.prompt_cache[prompt_txt].to(device=self.device_torch, dtype=dtype)
        neutral = self.prompt_cache[""].to(device=self.device_torch, dtype=dtype)

        if prompt is None:
            raise ValueError(f"Prompt {prompt_txt} is not in cache")

        prompt_batch = train_tools.concat_prompt_embeddings(
            prompt,
            neutral,
            self.train_config.batch_size,
        )

        noise_scheduler = self.sd.noise_scheduler
        optimizer = self.optimizer
        lr_scheduler = self.lr_scheduler
        loss_function = torch.nn.MSELoss()

        def get_noise_pred(p, n, gs, cts, dn):
            return self.predict_noise(
                latents=dn,
                text_embeddings=train_tools.concat_prompt_embeddings(
                    p,  # unconditional
                    n,  # positive
                    self.train_config.batch_size,
                ),
                timestep=cts,
                guidance_scale=gs,
            )

        with torch.no_grad():
            self.sd.noise_scheduler.set_timesteps(
                self.train_config.max_denoising_steps, device=self.device_torch
            )

            self.optimizer.zero_grad()

            # get a random number of steps
            timesteps_to = torch.randint(
                1, self.train_config.max_denoising_steps, (1,)
            ).item()

            # get noise
            noise = self.get_latent_noise(
                pixel_height=self.rescale_config.from_resolution,
                pixel_width=self.rescale_config.from_resolution,
            ).to(self.device_torch, dtype=dtype)

            # get latents
            latents = noise * self.sd.noise_scheduler.init_noise_sigma
            latents = latents.to(self.device_torch, dtype=dtype)

            # # predict without network
            # assert self.network.is_active is False
            # denoised_latents = self.diffuse_some_steps(
            #     latents,  # pass simple noise latents
            #     prompt_batch,
            #     start_timesteps=0,
            #     total_timesteps=timesteps_to,
            #     guidance_scale=3,
            # )
            # noise_scheduler.set_timesteps(1000)
            #
            # current_timestep = noise_scheduler.timesteps[
            #     int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
            # ]
            current_timestep = 0
            denoised_latents = latents

            # get noise prediction at full scale
            from_prediction = get_noise_pred(
                prompt, neutral, 1, current_timestep, denoised_latents
            )
            reduced_from_prediction = self.reduce_size_fn(from_prediction).to("cpu", dtype=torch.float32)

            # get noise prediction at reduced scale
            to_denoised_latents = self.reduce_size_fn(denoised_latents)

        # start gradient
        optimizer.zero_grad()

        self.network.multiplier = 1.0
        with self.network:
            assert self.network.is_active is True
            to_prediction = get_noise_pred(
                prompt, neutral, 1, current_timestep, to_denoised_latents
            ).to("cpu", dtype=torch.float32)

        reduced_from_prediction.requires_grad = False
        from_prediction.requires_grad = False

        loss = loss_function(
            reduced_from_prediction,
            to_prediction,
        )

        loss_float = loss.item()

        loss = loss.to(self.device_torch)
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        del (
            reduced_from_prediction,
            from_prediction,
            to_denoised_latents,
            to_prediction,
            latents,
        )
        flush()

        # reset network
        self.network.multiplier = 1.0

        loss_dict = OrderedDict(
            {'loss': loss_float},
        )

        return loss_dict
        # end hook_train_loop
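One training step above amounts to: predict noise on full-scale latents without the LoRA, average-pool that prediction down, and train the LoRA so its prediction on pooled latents matches. A condensed sketch of that objective; unet_eps and its lora flag are illustrative stand-ins, not names from this file:

import torch
import torch.nn.functional as F

def rescale_step(unet_eps, reduce, latents_hi, optimizer):
    # reduce is the average-pooling ReductionKernel from toolkit/layers.py below
    with torch.no_grad():
        target = reduce(unet_eps(latents_hi, lora=False))  # full-res teacher, pooled down
    pred = unet_eps(reduce(latents_hi), lora=True)  # LoRA student at reduced scale
    loss = F.mse_loss(pred, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()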

View File

@@ -669,7 +669,7 @@ class TrainVAEProcess(BaseTrainProcess):
                 if self.writer is not None:
                     # get avg loss
                     for key in log_losses:
-                        log_losses[key] = sum(log_losses[key]) / len(log_losses[key])
+                        log_losses[key] = sum(log_losses[key]) / (len(log_losses[key]) + 1e-6)
                         # if log_losses[key] > 0:
                         self.writer.add_scalar(f"loss/{key}", log_losses[key], self.step_num)
                     # reset log losses
@@ -678,9 +678,10 @@ class TrainVAEProcess(BaseTrainProcess):
                 self.step_num += 1
         # end epoch
         if self.writer is not None:
+            eps = 1e-6
             # get avg loss
             for key in epoch_losses:
-                epoch_losses[key] = sum(log_losses[key]) / len(log_losses[key])
+                epoch_losses[key] = sum(log_losses[key]) / (len(log_losses[key]) + eps)
                 if epoch_losses[key] > 0:
                     self.writer.add_scalar(f"epoch loss/{key}", epoch_losses[key], epoch)
             # reset epoch losses

View File

@@ -7,3 +7,4 @@ from .TrainVAEProcess import TrainVAEProcess
 from .BaseMergeProcess import BaseMergeProcess
 from .TrainSliderProcess import TrainSliderProcess
 from .TrainLoRAHack import TrainLoRAHack
+from .TrainSDRescaleProcess import TrainSDRescaleProcess

View File

@@ -10,4 +10,6 @@ pyyaml
 oyaml
 tensorboard
 kornia
 invisible-watermark
+einops
+accelerate

View File

@@ -5,6 +5,7 @@ class SaveConfig:
     def __init__(self, **kwargs):
         self.save_every: int = kwargs.get('save_every', 1000)
         self.dtype: str = kwargs.get('save_dtype', 'float16')
+        self.max_step_saves_to_keep: int = kwargs.get('max_step_saves_to_keep', 5)


 class LogingConfig:
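A hypothetical save block exercising the new key (values illustrative):

save_config = SaveConfig(save_every=500, save_dtype='float16', max_step_saves_to_keep=3)
# clean_up_saves() in BaseSDTrainProcess will then keep only the 3 newest
# {job_name}_*.safetensors step saves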
@@ -30,8 +31,16 @@ class SampleConfig:

 class NetworkConfig:
     def __init__(self, **kwargs):
-        self.type: str = kwargs.get('type', 'lierla')
-        self.rank: int = kwargs.get('rank', 4)
+        self.type: str = kwargs.get('type', 'lora')
+        rank = kwargs.get('rank', None)
+        linear = kwargs.get('linear', None)
+        if rank is not None:
+            self.rank: int = rank  # rank kept for backward compatibility
+            self.linear: int = rank
+        elif linear is not None:
+            self.rank: int = linear
+            self.linear: int = linear
+        self.conv: int = kwargs.get('conv', None)
         self.alpha: float = kwargs.get('alpha', 1.0)
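Both spellings of the linear dimension resolve to the same attributes; a quick sketch (note that as written, neither attribute is set when both keys are absent, so configs are expected to pass one of them):

assert NetworkConfig(rank=8).linear == 8   # old-style configs
assert NetworkConfig(linear=8).rank == 8   # new-style configs
# LoCon-style networks add conv ranks alongside the linear ones:
locon = NetworkConfig(linear=16, conv=8, alpha=8.0)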
@@ -51,6 +60,7 @@ class TrainConfig:
         self.noise_offset = kwargs.get('noise_offset', 0.0)
         self.optimizer_params = kwargs.get('optimizer_params', {})
         self.skip_first_sample = kwargs.get('skip_first_sample', False)
+        self.gradient_checkpointing = kwargs.get('gradient_checkpointing', False)


 class ModelConfig:

toolkit/layers.py (new file, 31 lines)
View File

@@ -0,0 +1,31 @@
import torch
import torch.nn as nn
import numpy as np


class ReductionKernel(nn.Module):
    # port of a TensorFlow-style averaging (reduction) kernel; see layout note in build_kernel
    def __init__(self, in_channels, kernel_size=2, dtype=torch.float32, device=None):
        if device is None:
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        super(ReductionKernel, self).__init__()
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        numpy_kernel = self.build_kernel()
        self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)

    def build_kernel(self):
        # tensorflow kernel is (height, width, in_channels, out_channels)
        # pytorch kernel is (out_channels, in_channels, height, width)
        kernel_size = self.kernel_size
        channels = self.in_channels
        kernel_shape = [channels, channels, kernel_size, kernel_size]
        kernel = np.zeros(kernel_shape, np.float32)
        kernel_value = 1.0 / (kernel_size * kernel_size)
        for i in range(0, channels):
            kernel[i, i, :, :] = kernel_value
        return kernel

    def forward(self, x):
        return nn.functional.conv2d(x, self.kernel, stride=self.kernel_size, padding=0, groups=1)
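Since the kernel averages each channel independently over a kernel_size window with a matching stride, forward() is channel-wise average pooling; a quick equivalence check against F.avg_pool2d (4 channels assumed, matching SD latents):

import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 64, 64)
reduce = ReductionKernel(in_channels=4, kernel_size=2, device=torch.device("cpu"))
assert torch.allclose(reduce(x), F.avg_pool2d(x, kernel_size=2), atol=1e-6)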

View File

@@ -1,5 +1,8 @@
 import json
 from collections import OrderedDict

+from safetensors import safe_open
+
 from info import software_meta

@@ -25,4 +28,10 @@ def parse_metadata_from_safetensors(meta: OrderedDict) -> OrderedDict:
             parsed_meta[key] = json.loads(value)
         except json.decoder.JSONDecodeError:
             parsed_meta[key] = value
-    return meta
+    return parsed_meta
+
+
+def load_metadata_from_safetensors(file_path: str) -> OrderedDict:
+    with safe_open(file_path, framework="pt") as f:
+        metadata = f.metadata()
+    return parse_metadata_from_safetensors(metadata)
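A hypothetical usage sketch, mirroring what load_weights() in BaseSDTrainProcess does with the result; the file path is illustrative:

meta = load_metadata_from_safetensors("output/my_job_000001000.safetensors")
if 'training_info' in meta and 'step' in meta['training_info']:
    start_step = meta['training_info']['step']  # resume training from this step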

View File

@@ -1,11 +1,18 @@
-from typing import Union
+from typing import Union, OrderedDict
 import sys
 import os

+from safetensors.torch import save_file
+
 from toolkit.paths import REPOS_ROOT
+from toolkit.train_tools import get_torch_dtype

 sys.path.append(REPOS_ROOT)
 sys.path.append(os.path.join(REPOS_ROOT, 'leco'))

 from leco import train_util
 import torch
+from library import model_util
+from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl


 class PromptEmbeds:
@@ -22,6 +29,12 @@ class PromptEmbeds:
             self.text_embeds = args
             self.pooled_embeds = None

+    def to(self, **kwargs):
+        self.text_embeds = self.text_embeds.to(**kwargs)
+        if self.pooled_embeds is not None:
+            self.pooled_embeds = self.pooled_embeds.to(**kwargs)
+        return self
+

 class StableDiffusion:
     def __init__(
@@ -61,3 +74,41 @@ class StableDiffusion:
                 self.tokenizer, self.text_encoder, prompt
             )
         )

+    def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
+        # todo see what logit scale is
+        if self.is_xl:
+            state_dict = {}
+
+            def update_sd(prefix, sd):
+                for k, v in sd.items():
+                    key = prefix + k
+                    v = v.detach().clone().to("cpu").to(get_torch_dtype(save_dtype))
+                    state_dict[key] = v
+
+            # Convert the UNet model
+            update_sd("model.diffusion_model.", self.unet.state_dict())
+
+            # Convert the text encoders
+            update_sd("conditioner.embedders.0.transformer.", self.text_encoder[0].state_dict())
+
+            text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(self.text_encoder[1].state_dict(), logit_scale)
+            update_sd("conditioner.embedders.1.model.", text_enc2_dict)
+
+            # Convert the VAE
+            vae_dict = model_util.convert_vae_state_dict(self.vae.state_dict())
+            update_sd("first_stage_model.", vae_dict)
+
+            # Put together new checkpoint
+            key_count = len(state_dict.keys())
+            new_ckpt = {"state_dict": state_dict}
+
+            if model_util.is_safetensors(output_file):
+                save_file(state_dict, output_file)
+            else:
+                torch.save(new_ckpt, output_file, meta)
+
+            return key_count
+        else:
+            raise NotImplementedError("sdv1.x, sdv2.x is not implemented yet")

View File

@@ -2,6 +2,7 @@ import argparse
 import json
 import os
 import time
+from typing import TYPE_CHECKING

 from diffusers import (
     StableDiffusionPipeline,
@@ -21,8 +22,6 @@ from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipel
 import torch
 import re

-from toolkit.stable_diffusion_model import PromptEmbeds
-
 SCHEDULER_LINEAR_START = 0.00085
 SCHEDULER_LINEAR_END = 0.0120
 SCHEDULER_TIMESTEPS = 1000
@@ -381,11 +380,16 @@ def apply_noise_offset(noise, noise_offset):
     return noise


+if TYPE_CHECKING:
+    from toolkit.stable_diffusion_model import PromptEmbeds
+
+
 def concat_prompt_embeddings(
-    unconditional: PromptEmbeds,
-    conditional: PromptEmbeds,
+    unconditional: 'PromptEmbeds',
+    conditional: 'PromptEmbeds',
     n_imgs: int,
 ):
+    from toolkit.stable_diffusion_model import PromptEmbeds
     text_embeds = torch.cat(
         [unconditional.text_embeds, conditional.text_embeds]
     ).repeat_interleave(n_imgs, dim=0)
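The change above is the standard recipe for breaking a circular import: the top-level import runs only for static type checkers, the annotations become strings, and the real import is deferred to call time. A minimal sketch of the pattern with illustrative module names:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # evaluated by type checkers only, never at runtime
    from mypackage.models import PromptEmbeds

def combine(unconditional: 'PromptEmbeds', conditional: 'PromptEmbeds'):
    # runtime import happens here, once both modules are fully initialized
    from mypackage.models import PromptEmbeds
    ...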