diff --git a/jobs/TrainJob.py b/jobs/TrainJob.py index d988f24e..1d5282ef 100644 --- a/jobs/TrainJob.py +++ b/jobs/TrainJob.py @@ -18,6 +18,7 @@ process_dict = { 'vae': 'TrainVAEProcess', 'slider': 'TrainSliderProcess', 'lora_hack': 'TrainLoRAHack', + 'rescale_sd': 'TrainSDRescaleProcess', } diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index f32d3fb1..32b55e22 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -1,7 +1,10 @@ +import glob import time from collections import OrderedDict import os +from safetensors import safe_open + from toolkit.kohya_model_util import load_vae from toolkit.lora_special import LoRASpecialNetwork from toolkit.optimizer import get_optimizer @@ -14,7 +17,7 @@ sys.path.append(os.path.join(REPOS_ROOT, 'leco')) from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline from jobs.process import BaseTrainProcess -from toolkit.metadata import get_meta_for_safetensors +from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors from toolkit.train_tools import get_torch_dtype, apply_noise_offset import gc @@ -48,6 +51,7 @@ class BaseSDTrainProcess(BaseTrainProcess): self.model_config = ModelConfig(**self.get_conf('model', {})) self.save_config = SaveConfig(**self.get_conf('save', {})) self.sample_config = SampleConfig(**self.get_conf('sample', {})) + self.first_sample_config = SampleConfig(**self.get_conf('first_sample', {})) if 'first_sample' in self.config else self.sample_config self.logging_config = LogingConfig(**self.get_conf('logging', {})) self.optimizer = None self.lr_scheduler = None @@ -56,7 +60,7 @@ class BaseSDTrainProcess(BaseTrainProcess): # added later self.network = None - def sample(self, step=None): + def sample(self, step=None, is_first=False): sample_folder = os.path.join(self.save_root, 'samples') if not os.path.exists(sample_folder): os.makedirs(sample_folder, exist_ok=True) @@ -112,7 +116,9 @@ class BaseSDTrainProcess(BaseTrainProcess): # disable progress bar pipeline.set_progress_bar_config(disable=True) - start_seed = self.sample_config.seed + sample_config = self.first_sample_config if is_first else self.sample_config + + start_seed = sample_config.seed start_multiplier = self.network.multiplier current_seed = start_seed @@ -127,14 +133,16 @@ class BaseSDTrainProcess(BaseTrainProcess): 'multiplier': self.network.multiplier, }) - for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}", + for i in tqdm(range(len(sample_config.prompts)), desc=f"Generating Samples - step: {step}", leave=False): - raw_prompt = self.sample_config.prompts[i] + raw_prompt = sample_config.prompts[i] - neg = self.sample_config.neg - multiplier = self.sample_config.network_multiplier + neg = sample_config.neg + multiplier = sample_config.network_multiplier p_split = raw_prompt.split('--') prompt = p_split[0].strip() + height = sample_config.height + width = sample_config.width if len(p_split) > 1: for split in p_split: @@ -145,13 +153,17 @@ class BaseSDTrainProcess(BaseTrainProcess): elif flag == 'm': # multiplier multiplier = float(content) + elif flag == 'w': + # multiplier + width = int(content) + elif flag == 'h': + # multiplier + height = int(content) - height = self.sample_config.height - width = self.sample_config.width height = max(64, height - height % 8) # round to divisible by 8 width = max(64, width - width % 8) # round to divisible by 8 - if self.sample_config.walk_seed: + if sample_config.walk_seed: 
                 current_seed += i
 
             if self.network is not None:
@@ -159,14 +171,24 @@
             torch.manual_seed(current_seed)
             torch.cuda.manual_seed(current_seed)
 
-            img = pipeline(
-                prompt,
-                height=height,
-                width=width,
-                num_inference_steps=self.sample_config.sample_steps,
-                guidance_scale=self.sample_config.guidance_scale,
-                negative_prompt=neg,
-            ).images[0]
+            if self.sd.is_xl:
+                img = pipeline(
+                    prompt,
+                    height=height,
+                    width=width,
+                    num_inference_steps=sample_config.sample_steps,
+                    guidance_scale=sample_config.guidance_scale,
+                    negative_prompt=neg,
+                ).images[0]
+            else:
+                img = pipeline(
+                    prompt,
+                    height=height,
+                    width=width,
+                    num_inference_steps=sample_config.sample_steps,
+                    guidance_scale=sample_config.guidance_scale,
+                    negative_prompt=neg,
+                ).images[0]
 
             step_num = ''
             if step is not None:
@@ -209,6 +231,24 @@
         })
         return info
 
+    def clean_up_saves(self):
+        # remove old saves
+        # get latest saved step
+        if os.path.exists(self.save_root):
+            latest_file = None
+            # pattern is {job_name}_{zero_filled_step}.safetensors but NOT {job_name}.safetensors
+            pattern = f"{self.job.name}_*.safetensors"
+            files = glob.glob(os.path.join(self.save_root, pattern))
+            if len(files) > self.save_config.max_step_saves_to_keep:
+                # remove all but the latest max_step_saves_to_keep
+                files.sort(key=os.path.getctime)
+                for file in files[:-self.save_config.max_step_saves_to_keep]:
+                    self.print(f"Removing old save: {file}")
+                    os.remove(file)
+            return latest_file
+        else:
+            return None
+
     def save(self, step=None):
         if not os.path.exists(self.save_root):
             os.makedirs(self.save_root, exist_ok=True)
@@ -231,9 +271,11 @@
                 metadata=save_meta
             )
         else:
-            # TODO handle dreambooth, fine tuning, etc
-            # will probably have to convert dict back to LDM
-            ValueError("Non network training is not currently supported")
+            self.sd.save(
+                file_path,
+                save_meta,
+                get_torch_dtype(self.save_config.dtype)
+            )
 
         self.print(f"Saved to {file_path}")
 
@@ -316,18 +359,47 @@
             if add_time_ids is None:
                 add_time_ids = self.get_time_ids_from_latents(latents)
             # todo LECOs code looks like it is omitting noise_pred
-            noise_pred = train_util.predict_noise_xl(
-                self.sd.unet,
-                self.sd.noise_scheduler,
+            # noise_pred = train_util.predict_noise_xl(
+            #     self.sd.unet,
+            #     self.sd.noise_scheduler,
+            #     timestep,
+            #     latents,
+            #     text_embeddings.text_embeds,
+            #     text_embeddings.pooled_embeds,
+            #     add_time_ids,
+            #     guidance_scale=guidance_scale,
+            #     guidance_rescale=guidance_rescale
+            # )
+            latent_model_input = torch.cat([latents] * 2)
+
+            latent_model_input = self.sd.noise_scheduler.scale_model_input(latent_model_input, timestep)
+
+            added_cond_kwargs = {
+                "text_embeds": text_embeddings.pooled_embeds,
+                "time_ids": add_time_ids,
+            }
+
+            # predict the noise residual
+            noise_pred = self.sd.unet(
+                latent_model_input,
                 timestep,
-                latents,
-                text_embeddings.text_embeds,
-                text_embeddings.pooled_embeds,
-                add_time_ids,
-                guidance_scale=guidance_scale,
-                guidance_rescale=guidance_rescale
+                encoder_hidden_states=text_embeddings.text_embeds,
+                added_cond_kwargs=added_cond_kwargs,
+            ).sample
+
+            # perform guidance
+            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+            guided_target = noise_pred_uncond + guidance_scale * (
+                noise_pred_text - noise_pred_uncond
             )
+            # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
+            # noise_pred = rescale_noise_cfg(
+            #     noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
+            # )
+
+            noise_pred = guided_target
+
         else:
             noise_pred = train_util.predict_noise(
                 self.sd.unet,
@@ -366,6 +438,32 @@
         # return latents_steps
         return latents
 
+    def get_latest_save_path(self):
+        # get latest saved step
+        if os.path.exists(self.save_root):
+            latest_file = None
+            # pattern is {job_name}_{zero_filled_step}.safetensors or {job_name}.safetensors
+            pattern = f"{self.job.name}*.safetensors"
+            files = glob.glob(os.path.join(self.save_root, pattern))
+            if len(files) > 0:
+                latest_file = max(files, key=os.path.getctime)
+            return latest_file
+        else:
+            return None
+
+    def load_weights(self, path):
+        if self.network is not None:
+            self.network.load_weights(path)
+            meta = load_metadata_from_safetensors(path)
+            # if 'training_info' is in the OrderedDict keys
+            if 'training_info' in meta and 'step' in meta['training_info']:
+                self.step_num = meta['training_info']['step']
+                self.start_step = self.step_num
+                print(f"Found step {self.step_num} in metadata, starting from there")
+
+        else:
+            print("load_weights not implemented for non-network models")
+
     def run(self):
         super().run()
 
@@ -407,20 +505,26 @@
         unet.to(self.device_torch, dtype=dtype)
         if self.train_config.xformers:
             unet.enable_xformers_memory_efficient_attention()
+        if self.train_config.gradient_checkpointing:
+            unet.enable_gradient_checkpointing()
         unet.requires_grad_(False)
         unet.eval()
 
         if self.network_config is not None:
+            conv = self.network_config.conv if self.network_config.conv is not None and self.network_config.conv > 0 else None
             self.network = LoRASpecialNetwork(
                 text_encoder=text_encoder,
                 unet=unet,
-                lora_dim=self.network_config.rank,
+                lora_dim=self.network_config.linear,
                 multiplier=1.0,
                 alpha=self.network_config.alpha,
                 train_unet=self.train_config.train_unet,
                 train_text_encoder=self.train_config.train_text_encoder,
+                conv_lora_dim=conv,
+                conv_alpha=self.network_config.alpha if conv is not None else None,
             )
+
             self.network.force_to(self.device_torch, dtype=dtype)
 
             self.network.apply_to(
@@ -438,6 +542,15 @@
                 default_lr=self.train_config.lr
             )
 
+            latest_save_path = self.get_latest_save_path()
+            if latest_save_path is not None:
+                self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
+                self.print(f"Loading from {latest_save_path}")
+                self.load_weights(latest_save_path)
+                self.network.multiplier = 1.0
+
+
+
         else:
             params = []
             # assume dreambooth/finetune
@@ -475,15 +588,17 @@
             self.print("Skipping first sample due to config setting")
         else:
             self.print("Generating baseline samples before training")
-            self.sample(0)
+            self.sample(0, is_first=True)
 
         self.progress_bar = tqdm(
             total=self.train_config.steps,
             desc=self.job.name,
             leave=True
         )
-        self.step_num = 0
-        for step in range(self.train_config.steps):
+        # set it to our current step in case it was updated from a load
+        self.progress_bar.update(self.step_num)
+        # self.step_num = 0
+        for step in range(self.step_num,
self.train_config.steps): # todo handle dataloader here maybe, not sure ### HOOK ### diff --git a/jobs/process/TrainSDRescaleProcess.py b/jobs/process/TrainSDRescaleProcess.py new file mode 100644 index 00000000..582d6ffa --- /dev/null +++ b/jobs/process/TrainSDRescaleProcess.py @@ -0,0 +1,278 @@ +# ref: +# - https://github.com/p1atdev/LECO/blob/main/train_lora.py +import time +from collections import OrderedDict +import os +from typing import Optional + +from safetensors.torch import load_file, save_file +from tqdm import tqdm + +from toolkit.config_modules import SliderConfig +from toolkit.layers import ReductionKernel +from toolkit.paths import REPOS_ROOT +import sys + +from toolkit.stable_diffusion_model import PromptEmbeds + +sys.path.append(REPOS_ROOT) +sys.path.append(os.path.join(REPOS_ROOT, 'leco')) +from toolkit.train_tools import get_torch_dtype, apply_noise_offset +import gc +from toolkit import train_tools + +import torch +from leco import train_util, model_util +from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class RescaleConfig: + def __init__( + self, + **kwargs + ): + self.from_resolution = kwargs.get('from_resolution', 512) + self.scale = kwargs.get('scale', 0.5) + self.prompt_file = kwargs.get('prompt_file', None) + self.prompt_tensors = kwargs.get('prompt_tensors', None) + self.to_resolution = kwargs.get('to_resolution', int(self.from_resolution * self.scale)) + + if self.prompt_file is None: + raise ValueError("prompt_file is required") + + +class PromptEmbedsCache: + prompts: dict[str, PromptEmbeds] = {} + + def __setitem__(self, __name: str, __value: PromptEmbeds) -> None: + self.prompts[__name] = __value + + def __getitem__(self, __name: str) -> Optional[PromptEmbeds]: + if __name in self.prompts: + return self.prompts[__name] + else: + return None + + +class TrainSDRescaleProcess(BaseSDTrainProcess): + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + self.step_num = 0 + self.start_step = 0 + self.device = self.get_conf('device', self.job.device) + self.device_torch = torch.device(self.device) + self.prompt_cache = PromptEmbedsCache() + self.rescale_config = RescaleConfig(**self.get_conf('rescale', required=True)) + self.reduce_size_fn = ReductionKernel( + in_channels=4, + kernel_size=int(self.rescale_config.from_resolution // self.rescale_config.to_resolution), + dtype=get_torch_dtype(self.train_config.dtype), + device=self.device_torch, + ) + self.prompt_txt_list = [] + + def before_model_load(self): + pass + + def hook_before_train_loop(self): + self.print(f"Loading prompt file from {self.rescale_config.prompt_file}") + + # read line by line from file + with open(self.rescale_config.prompt_file, 'r') as f: + self.prompt_txt_list = f.readlines() + # clean empty lines + self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0] + + self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..") + + cache = PromptEmbedsCache() + + # get encoded latents for our prompts + with torch.no_grad(): + if self.rescale_config.prompt_tensors is not None: + # check to see if it exists + if os.path.exists(self.rescale_config.prompt_tensors): + # load it. 
+ self.print(f"Loading prompt tensors from {self.rescale_config.prompt_tensors}") + prompt_tensors = load_file(self.rescale_config.prompt_tensors, device='cpu') + # add them to the cache + for prompt_txt, prompt_tensor in prompt_tensors.items(): + if prompt_txt.startswith("te:"): + prompt = prompt_txt[3:] + # text_embeds + text_embeds = prompt_tensor + pooled_embeds = None + # find pool embeds + if f"pe:{prompt}" in prompt_tensors: + pooled_embeds = prompt_tensors[f"pe:{prompt}"] + + # make it + prompt_embeds = PromptEmbeds([text_embeds, pooled_embeds]) + cache[prompt] = prompt_embeds.to(device='cpu', dtype=torch.float32) + + if len(cache.prompts) == 0: + print("Prompt tensors not found. Encoding prompts..") + neutral = "" + # encode neutral + cache[neutral] = self.sd.encode_prompt(neutral) + for prompt in tqdm(self.prompt_txt_list, desc="Encoding prompts", leave=False): + # build the cache + if cache[prompt] is None: + cache[prompt] = self.sd.encode_prompt(prompt).to(device="cpu", dtype=torch.float32) + + if self.rescale_config.prompt_tensors: + print(f"Saving prompt tensors to {self.rescale_config.prompt_tensors}") + state_dict = {} + for prompt_txt, prompt_embeds in cache.prompts.items(): + state_dict[f"te:{prompt_txt}"] = prompt_embeds.text_embeds.to("cpu", dtype=get_torch_dtype('fp16')) + if prompt_embeds.pooled_embeds is not None: + state_dict[f"pe:{prompt_txt}"] = prompt_embeds.pooled_embeds.to("cpu", dtype=get_torch_dtype('fp16')) + save_file(state_dict, self.rescale_config.prompt_tensors) + + self.print("Encoding complete.") + + # move to cpu to save vram + # We don't need text encoder anymore, but keep it on cpu for sampling + # if text encoder is list + if isinstance(self.sd.text_encoder, list): + for encoder in self.sd.text_encoder: + encoder.to("cpu") + else: + self.sd.text_encoder.to("cpu") + self.prompt_cache = cache + + flush() + # end hook_before_train_loop + + def hook_train_loop(self): + dtype = get_torch_dtype(self.train_config.dtype) + + # get random encoded prompt from cache + prompt_txt = self.prompt_txt_list[ + torch.randint(0, len(self.prompt_txt_list), (1,)).item() + ] + prompt = self.prompt_cache[prompt_txt].to(device=self.device_torch, dtype=dtype) + neutral = self.prompt_cache[""].to(device=self.device_torch, dtype=dtype) + if prompt is None: + raise ValueError(f"Prompt {prompt_txt} is not in cache") + + prompt_batch = train_tools.concat_prompt_embeddings( + prompt, + neutral, + self.train_config.batch_size, + ) + + noise_scheduler = self.sd.noise_scheduler + optimizer = self.optimizer + lr_scheduler = self.lr_scheduler + loss_function = torch.nn.MSELoss() + + def get_noise_pred(p, n, gs, cts, dn): + return self.predict_noise( + latents=dn, + text_embeddings=train_tools.concat_prompt_embeddings( + p, # unconditional + n, # positive + self.train_config.batch_size, + ), + timestep=cts, + guidance_scale=gs, + ) + + with torch.no_grad(): + self.sd.noise_scheduler.set_timesteps( + self.train_config.max_denoising_steps, device=self.device_torch + ) + + self.optimizer.zero_grad() + + # # ger a random number of steps + timesteps_to = torch.randint( + 1, self.train_config.max_denoising_steps, (1,) + ).item() + + # get noise + noise = self.get_latent_noise( + pixel_height=self.rescale_config.from_resolution, + pixel_width=self.rescale_config.from_resolution, + ).to(self.device_torch, dtype=dtype) + + # get latents + latents = noise * self.sd.noise_scheduler.init_noise_sigma + latents = latents.to(self.device_torch, dtype=dtype) + # + # # predict without network + # 
assert self.network.is_active is False + # denoised_latents = self.diffuse_some_steps( + # latents, # pass simple noise latents + # prompt_batch, + # start_timesteps=0, + # total_timesteps=timesteps_to, + # guidance_scale=3, + # ) + # noise_scheduler.set_timesteps(1000) + # + # current_timestep = noise_scheduler.timesteps[ + # int(timesteps_to * 1000 / self.train_config.max_denoising_steps) + # ] + + current_timestep = 0 + denoised_latents = latents + # get noise prediction at full scale + from_prediction = get_noise_pred( + prompt, neutral, 1, current_timestep, denoised_latents + ) + + reduced_from_prediction = self.reduce_size_fn(from_prediction).to("cpu", dtype=torch.float32) + + # get noise prediction at reduced scale + to_denoised_latents = self.reduce_size_fn(denoised_latents) + + # start gradient + optimizer.zero_grad() + self.network.multiplier = 1.0 + with self.network: + assert self.network.is_active is True + to_prediction = get_noise_pred( + prompt, neutral, 1, current_timestep, to_denoised_latents + ).to("cpu", dtype=torch.float32) + + reduced_from_prediction.requires_grad = False + from_prediction.requires_grad = False + + loss = loss_function( + reduced_from_prediction, + to_prediction, + ) + + loss_float = loss.item() + + loss = loss.to(self.device_torch) + + loss.backward() + optimizer.step() + lr_scheduler.step() + + del ( + reduced_from_prediction, + from_prediction, + to_denoised_latents, + to_prediction, + latents, + ) + flush() + + # reset network + self.network.multiplier = 1.0 + + loss_dict = OrderedDict( + {'loss': loss_float}, + ) + + return loss_dict + # end hook_train_loop diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py index c62f4ef9..5eba8174 100644 --- a/jobs/process/TrainVAEProcess.py +++ b/jobs/process/TrainVAEProcess.py @@ -669,7 +669,7 @@ class TrainVAEProcess(BaseTrainProcess): if self.writer is not None: # get avg loss for key in log_losses: - log_losses[key] = sum(log_losses[key]) / len(log_losses[key]) + log_losses[key] = sum(log_losses[key]) / (len(log_losses[key]) + 1e-6) # if log_losses[key] > 0: self.writer.add_scalar(f"loss/{key}", log_losses[key], self.step_num) # reset log losses @@ -678,9 +678,10 @@ class TrainVAEProcess(BaseTrainProcess): self.step_num += 1 # end epoch if self.writer is not None: + eps = 1e-6 # get avg loss for key in epoch_losses: - epoch_losses[key] = sum(log_losses[key]) / len(log_losses[key]) + epoch_losses[key] = sum(log_losses[key]) / (len(log_losses[key]) + eps) if epoch_losses[key] > 0: self.writer.add_scalar(f"epoch loss/{key}", epoch_losses[key], epoch) # reset epoch losses diff --git a/jobs/process/__init__.py b/jobs/process/__init__.py index 00a2fbbe..6329a213 100644 --- a/jobs/process/__init__.py +++ b/jobs/process/__init__.py @@ -7,3 +7,4 @@ from .TrainVAEProcess import TrainVAEProcess from .BaseMergeProcess import BaseMergeProcess from .TrainSliderProcess import TrainSliderProcess from .TrainLoRAHack import TrainLoRAHack +from .TrainSDRescaleProcess import TrainSDRescaleProcess \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index e099b7d5..df8cf9ea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,6 @@ pyyaml oyaml tensorboard kornia -invisible-watermark \ No newline at end of file +invisible-watermark +einops +accelerate diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index e393f4a0..5e5c3623 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -5,6 +5,7 @@ class SaveConfig: def __init__(self, 
**kwargs): self.save_every: int = kwargs.get('save_every', 1000) self.dtype: str = kwargs.get('save_dtype', 'float16') + self.max_step_saves_to_keep: int = kwargs.get('max_step_saves_to_keep', 5) class LogingConfig: @@ -30,8 +31,16 @@ class SampleConfig: class NetworkConfig: def __init__(self, **kwargs): - self.type: str = kwargs.get('type', 'lierla') - self.rank: int = kwargs.get('rank', 4) + self.type: str = kwargs.get('type', 'lora') + rank = kwargs.get('rank', None) + linear = kwargs.get('linear', None) + if rank is not None: + self.rank: int = rank # rank for backward compatibility + self.linear: int = rank + elif linear is not None: + self.rank: int = linear + self.linear: int = linear + self.conv: int = kwargs.get('conv', None) self.alpha: float = kwargs.get('alpha', 1.0) @@ -51,6 +60,7 @@ class TrainConfig: self.noise_offset = kwargs.get('noise_offset', 0.0) self.optimizer_params = kwargs.get('optimizer_params', {}) self.skip_first_sample = kwargs.get('skip_first_sample', False) + self.gradient_checkpointing = kwargs.get('gradient_checkpointing', False) class ModelConfig: diff --git a/toolkit/layers.py b/toolkit/layers.py new file mode 100644 index 00000000..2d6aaecb --- /dev/null +++ b/toolkit/layers.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn +import numpy as np + + +class ReductionKernel(nn.Module): + # Tensorflow + def __init__(self, in_channels, kernel_size=2, dtype=torch.float32, device=None): + if device is None: + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + super(ReductionKernel, self).__init__() + self.kernel_size = kernel_size + self.in_channels = in_channels + numpy_kernel = self.build_kernel() + self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype) + + def build_kernel(self): + # tensorflow kernel is (height, width, in_channels, out_channels) + # pytorch kernel is (out_channels, in_channels, height, width) + kernel_size = self.kernel_size + channels = self.in_channels + kernel_shape = [channels, channels, kernel_size, kernel_size] + kernel = np.zeros(kernel_shape, np.float32) + + kernel_value = 1.0 / (kernel_size * kernel_size) + for i in range(0, channels): + kernel[i, i, :, :] = kernel_value + return kernel + + def forward(self, x): + return nn.functional.conv2d(x, self.kernel, stride=self.kernel_size, padding=0, groups=1) diff --git a/toolkit/metadata.py b/toolkit/metadata.py index e5b1ce9e..6605feb3 100644 --- a/toolkit/metadata.py +++ b/toolkit/metadata.py @@ -1,5 +1,8 @@ import json from collections import OrderedDict + +from safetensors import safe_open + from info import software_meta @@ -25,4 +28,10 @@ def parse_metadata_from_safetensors(meta: OrderedDict) -> OrderedDict: parsed_meta[key] = json.loads(value) except json.decoder.JSONDecodeError: parsed_meta[key] = value - return meta + return parsed_meta + + +def load_metadata_from_safetensors(file_path: str) -> OrderedDict: + with safe_open(file_path, framework="pt") as f: + metadata = f.metadata() + return parse_metadata_from_safetensors(metadata) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 4541a2fe..9b1f6c09 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -1,11 +1,18 @@ -from typing import Union +from typing import Union, OrderedDict import sys import os + +from safetensors.torch import save_file + from toolkit.paths import REPOS_ROOT +from toolkit.train_tools import get_torch_dtype + sys.path.append(REPOS_ROOT) sys.path.append(os.path.join(REPOS_ROOT, 
'leco')) from leco import train_util import torch +from library import model_util +from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl class PromptEmbeds: @@ -22,6 +29,12 @@ class PromptEmbeds: self.text_embeds = args self.pooled_embeds = None + def to(self, **kwargs): + self.text_embeds = self.text_embeds.to(**kwargs) + if self.pooled_embeds is not None: + self.pooled_embeds = self.pooled_embeds.to(**kwargs) + return self + class StableDiffusion: def __init__( @@ -61,3 +74,41 @@ class StableDiffusion: self.tokenizer, self.text_encoder, prompt ) ) + + def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None): + # todo see what logit scale is + if self.is_xl: + + state_dict = {} + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + v = v.detach().clone().to("cpu").to(get_torch_dtype(save_dtype)) + state_dict[key] = v + + # Convert the UNet model + update_sd("model.diffusion_model.", self.unet.state_dict()) + + # Convert the text encoders + update_sd("conditioner.embedders.0.transformer.", self.text_encoder[0].state_dict()) + + text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(self.text_encoder[1].state_dict(), logit_scale) + update_sd("conditioner.embedders.1.model.", text_enc2_dict) + + # Convert the VAE + vae_dict = model_util.convert_vae_state_dict(self.vae.state_dict()) + update_sd("first_stage_model.", vae_dict) + + # Put together new checkpoint + key_count = len(state_dict.keys()) + new_ckpt = {"state_dict": state_dict} + + if model_util.is_safetensors(output_file): + save_file(state_dict, output_file) + else: + torch.save(new_ckpt, output_file, meta) + + return key_count + else: + raise NotImplementedError("sdv1.x, sdv2.x is not implemented yet") diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py index 8fa98a67..9ac9f31c 100644 --- a/toolkit/train_tools.py +++ b/toolkit/train_tools.py @@ -2,6 +2,7 @@ import argparse import json import os import time +from typing import TYPE_CHECKING from diffusers import ( StableDiffusionPipeline, @@ -21,8 +22,6 @@ from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipel import torch import re -from toolkit.stable_diffusion_model import PromptEmbeds - SCHEDULER_LINEAR_START = 0.00085 SCHEDULER_LINEAR_END = 0.0120 SCHEDULER_TIMESTEPS = 1000 @@ -381,11 +380,16 @@ def apply_noise_offset(noise, noise_offset): return noise +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import PromptEmbeds + + def concat_prompt_embeddings( - unconditional: PromptEmbeds, - conditional: PromptEmbeds, + unconditional: 'PromptEmbeds', + conditional: 'PromptEmbeds', n_imgs: int, ): + from toolkit.stable_diffusion_model import PromptEmbeds text_embeds = torch.cat( [unconditional.text_embeds, conditional.text_embeds] ).repeat_interleave(n_imgs, dim=0)
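A quick way to sanity-check the ReductionKernel layer added in toolkit/layers.py: its kernel is a fixed 1/(kernel_size^2) box filter applied per channel with stride equal to the kernel size, so it should behave exactly like average pooling. The snippet below is a minimal sketch, assuming the repo root is importable and using latent shapes consistent with a 512px from_resolution being reduced to 256px (kernel_size = from_resolution // to_resolution = 2); the concrete sizes are illustrative, not taken from the patch.

import torch
import torch.nn.functional as F

from toolkit.layers import ReductionKernel

# kernel_size = from_resolution // to_resolution, e.g. 512 // 256 = 2
reduce_fn = ReductionKernel(in_channels=4, kernel_size=2, dtype=torch.float32, device=torch.device("cpu"))

latents = torch.randn(1, 4, 64, 64)  # 4-channel SD latents for a 512px image
reduced = reduce_fn(latents)

print(reduced.shape)  # torch.Size([1, 4, 32, 32])
# the per-channel box kernel with stride == kernel_size is plain average pooling
assert torch.allclose(reduced, F.avg_pool2d(latents, kernel_size=2), atol=1e-6)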