diff --git a/jobs/process/BaseTrainProcess.py b/jobs/process/BaseTrainProcess.py index afcf9399..c80b9e79 100644 --- a/jobs/process/BaseTrainProcess.py +++ b/jobs/process/BaseTrainProcess.py @@ -1,10 +1,14 @@ +import os from collections import OrderedDict +from typing import ForwardRef + from jobs.process.BaseProcess import BaseProcess class BaseTrainProcess(BaseProcess): process_id: int config: OrderedDict + progress_bar: ForwardRef('tqdm') = None def __init__( self, @@ -13,8 +17,23 @@ class BaseTrainProcess(BaseProcess): config: OrderedDict ): super().__init__(process_id, job, config) + self.progress_bar = None + self.writer = self.job.writer + self.training_folder = self.get_conf('training_folder', self.job.training_folder) + self.save_root = os.path.join(self.training_folder, self.job.name) + self.step = 0 + self.first_step = 0 def run(self): + super().run() # implement in child class # be sure to call super().run() first pass + + # def print(self, message, **kwargs): + def print(self, *args): + if self.progress_bar is not None: + self.progress_bar.write(' '.join(map(str, args))) + self.progress_bar.update() + else: + print(*args) diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py new file mode 100644 index 00000000..eb3d008e --- /dev/null +++ b/jobs/process/TrainSliderProcess.py @@ -0,0 +1,609 @@ +# ref: +# - https://github.com/p1atdev/LECO/blob/main/train_lora.py +import time +from collections import OrderedDict +import os +from toolkit.kohya_model_util import load_vae +from toolkit.lora_special import LoRASpecialNetwork +from toolkit.paths import REPOS_ROOT +import sys + +sys.path.append(REPOS_ROOT) +sys.path.append(os.path.join(REPOS_ROOT, 'leco')) + +from diffusers import StableDiffusionPipeline + +from jobs.process import BaseTrainProcess +from toolkit.metadata import get_meta_for_safetensors +from toolkit.train_tools import get_torch_dtype +import gc + +import torch +from tqdm import tqdm + +from toolkit.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV, TRAINING_METHODS +from leco import train_util, model_util +from leco.prompt_util import PromptEmbedsCache, PromptEmbedsPair, ACTION_TYPES +from leco import debug_util + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class StableDiffusion: + def __init__(self, vae, tokenizer, text_encoder, unet, noise_scheduler): + self.vae = vae + self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.unet = unet + self.noise_scheduler = noise_scheduler + + +class SaveConfig: + def __init__(self, **kwargs): + self.save_every: int = kwargs.get('save_every', 1000) + self.dtype: str = kwargs.get('save_dtype', 'float16') + + +class LogingConfig: + def __init__(self, **kwargs): + self.log_every: int = kwargs.get('log_every', 100) + self.verbose: bool = kwargs.get('verbose', False) + self.use_wandb: bool = kwargs.get('use_wandb', False) + + +class SampleConfig: + def __init__(self, **kwargs): + self.sample_every: int = kwargs.get('sample_every', 100) + self.width: int = kwargs.get('width', 512) + self.height: int = kwargs.get('height', 512) + self.prompts: list[str] = kwargs.get('prompts', []) + self.neg = kwargs.get('neg', False) + self.seed = kwargs.get('seed', 0) + self.walk_seed = kwargs.get('walk_seed', False) + self.guidance_scale = kwargs.get('guidance_scale', 7) + self.sample_steps = kwargs.get('sample_steps', 20) + + +class NetworkConfig: + def __init__(self, **kwargs): + self.type: str = kwargs.get('type', 'lierla') + self.rank: int = 
kwargs.get('rank', 4) + self.alpha: float = kwargs.get('alpha', 1.0) + + +class TrainConfig: + def __init__(self, **kwargs): + self.noise_scheduler: 'model_util.AVAILABLE_SCHEDULERS' = kwargs.get('noise_scheduler', 'ddpm') + self.steps: int = kwargs.get('steps', 1000) + self.lr = kwargs.get('lr', 1e-6) + self.optimizer = kwargs.get('optimizer', 'adamw') + self.lr_scheduler = kwargs.get('lr_scheduler', 'constant') + self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 50) + self.batch_size: int = kwargs.get('batch_size', 1) + self.dtype: str = kwargs.get('dtype', 'fp32') + self.xformers = kwargs.get('xformers', False) + self.train_unet = kwargs.get('train_unet', True) + self.train_text_encoder = kwargs.get('train_text_encoder', True) + + +class ModelConfig: + def __init__(self, **kwargs): + self.name_or_path: str = kwargs.get('name_or_path', None) + self.is_v2: bool = kwargs.get('is_v2', False) + self.is_v_pred: bool = kwargs.get('is_v_pred', False) + + if self.name_or_path is None: + raise ValueError('name_or_path must be specified') + + +class PromptSettingsOld: + def __init__(self, **kwargs): + self.target: str = kwargs.get('target', None) + self.positive = kwargs.get('positive', None) # if None, target will be used + self.unconditional = kwargs.get('unconditional', "") # default is "" + self.neutral = kwargs.get('neutral', None) # if None, unconditional will be used + self.action: ACTION_TYPES = kwargs.get('action', "erase") # default is "erase" + self.guidance_scale: float = kwargs.get('guidance_scale', 1.0) # default is 1.0 + self.resolution: int = kwargs.get('resolution', 512) # default is 512 + self.dynamic_resolution: bool = kwargs.get('dynamic_resolution', False) # default is False + self.batch_size: int = kwargs.get('batch_size', 1) # default is 1 + self.dynamic_crops: bool = kwargs.get('dynamic_crops', False) # default is False. 
only used when model is XL + + +class TrainSliderProcess(BaseTrainProcess): + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + self.step_num = 0 + self.start_step = 0 + self.device = self.get_conf('device', self.job.device) + self.device_torch = torch.device(self.device) + self.network_config = NetworkConfig(**self.get_conf('network', {})) + self.training_folder = self.get_conf('training_folder', self.job.training_folder) + self.train_config = TrainConfig(**self.get_conf('train', {})) + self.model_config = ModelConfig(**self.get_conf('model', {})) + self.save_config = SaveConfig(**self.get_conf('save', {})) + self.sample_config = SampleConfig(**self.get_conf('sample', {})) + self.logging_config = LogingConfig(**self.get_conf('logging', {})) + self.sd = None + + self.prompt_settings = self.get_prompt_settings() + + # added later + self.network = None + self.scheduler = None + self.is_flipped = False + + def flip_network(self): + for param in self.network.parameters(): + # apply opposite weight to the network + param.data = -param.data + self.is_flipped = not self.is_flipped + + def get_prompt_settings(self): + prompts = self.get_conf('prompts', required=True) + prompt_settings = [PromptSettingsOld(**prompt) for prompt in prompts] + # for i, prompt in enumerate(prompts): + # prompt_settings[i].fill_prompts(prompt) + + return prompt_settings + + def sample(self, step=None): + sample_folder = os.path.join(self.save_root, 'samples') + if not os.path.exists(sample_folder): + os.makedirs(sample_folder, exist_ok=True) + + self.network.eval() + + # save current seed state for training + rng_state = torch.get_rng_state() + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + + original_device_dict = { + 'vae': self.sd.vae.device, + 'unet': self.sd.unet.device, + 'text_encoder': self.sd.text_encoder.device, + # 'tokenizer': self.sd.tokenizer.device, + } + + self.sd.vae.to(self.device_torch) + self.sd.unet.to(self.device_torch) + self.sd.text_encoder.to(self.device_torch) + # self.sd.tokenizer.to(self.device_torch) + # TODO add clip skip + + pipeline = StableDiffusionPipeline( + vae=self.sd.vae, + unet=self.sd.unet, + text_encoder=self.sd.text_encoder, + tokenizer=self.sd.tokenizer, + scheduler=self.sd.noise_scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + # disable progress bar + pipeline.set_progress_bar_config(disable=True) + + start_seed = self.sample_config.seed + current_seed = start_seed + + pipeline.to(self.device_torch) + with self.network: + with torch.no_grad(): + assert self.network.is_active + if self.logging_config.verbose: + print("network_state", { + 'is_active': self.network.is_active, + 'multiplier': self.network.multiplier, + }) + + for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}"): + raw_prompt = self.sample_config.prompts[i] + prompt = raw_prompt + neg = self.sample_config.neg + p_split = raw_prompt.split('--n') + if len(p_split) > 1: + prompt = p_split[0].strip() + neg = p_split[1].strip() + height = self.sample_config.height + width = self.sample_config.width + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + + if self.sample_config.walk_seed: + current_seed += i + + torch.manual_seed(current_seed) + torch.cuda.manual_seed(current_seed) + + img = pipeline( + prompt, + height=height, + width=width, + 
num_inference_steps=self.sample_config.sample_steps, + guidance_scale=self.sample_config.guidance_scale, + negative_prompt=neg, + ).images[0] + + step_num = '' + if step is not None: + # zero-pad 9 digits + step_num = f"_{str(step).zfill(9)}" + seconds_since_epoch = int(time.time()) + # zero-pad 2 digits + i_str = str(i).zfill(2) + filename = f"{seconds_since_epoch}{step_num}_{i_str}.png" + output_path = os.path.join(sample_folder, filename) + img.save(output_path) + + # clear pipeline and cache to reduce vram usage + del pipeline + torch.cuda.empty_cache() + + # restore training state + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + self.sd.vae.to(original_device_dict['vae']) + self.sd.unet.to(original_device_dict['unet']) + self.sd.text_encoder.to(original_device_dict['text_encoder']) + self.network.train() + # self.sd.tokenizer.to(original_device_dict['tokenizer']) + + def update_training_metadata(self): + self.add_meta(OrderedDict({"training_info": self.get_training_info()})) + + def get_training_info(self): + info = OrderedDict({ + 'step': self.step_num + }) + return info + + def save(self, step=None): + if not os.path.exists(self.save_root): + os.makedirs(self.save_root, exist_ok=True) + + step_num = '' + if step is not None: + # zeropad 9 digits + step_num = f"_{str(step).zfill(9)}" + + self.update_training_metadata() + filename = f'{self.job.name}{step_num}.safetensors' + file_path = os.path.join(self.save_root, filename) + # prepare meta + save_meta = get_meta_for_safetensors(self.meta, self.job.name) + self.network.save_weights( + file_path, + dtype=get_torch_dtype(self.save_config.dtype), + metadata=save_meta + ) + + self.print(f"Saved to {file_path}") + + def run(self): + super().run() + + dtype = get_torch_dtype(self.train_config.dtype) + + modules = DEFAULT_TARGET_REPLACE + loss = None + if self.network_config.type == "c3lier": + modules += UNET_TARGET_REPLACE_MODULE_CONV + + tokenizer, text_encoder, unet, noise_scheduler = model_util.load_models( + self.model_config.name_or_path, + scheduler_name=self.train_config.noise_scheduler, + v2=self.model_config.is_v2, + v_pred=self.model_config.is_v_pred, + ) + # just for now or of we want to load a custom one + # put on cpu for now, we only need it when sampling + vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype) + vae.eval() + self.sd = StableDiffusion(vae, tokenizer, text_encoder, unet, noise_scheduler) + + text_encoder.to(self.device_torch, dtype=dtype) + text_encoder.eval() + + unet.to(self.device_torch, dtype=dtype) + if self.train_config.xformers: + unet.enable_xformers_memory_efficient_attention() + unet.requires_grad_(False) + unet.eval() + + self.network = LoRASpecialNetwork( + text_encoder=text_encoder, + unet=unet, + lora_dim=self.network_config.rank, + multiplier=1.0, + alpha=self.network_config.alpha + ) + + self.network.force_to(self.device_torch, dtype=dtype) + + self.network.apply_to( + text_encoder, + unet, + self.train_config.train_text_encoder, + self.train_config.train_unet + ) + + self.network.prepare_grad_etc(text_encoder, unet) + + optimizer_type = self.train_config.optimizer.lower() + # we call it something different than leco + if optimizer_type == "dadaptation": + optimizer_type = "dadaptadam" + optimizer_module = train_util.get_optimizer(optimizer_type) + optimizer = optimizer_module( + self.network.prepare_optimizer_params( + self.train_config.lr, self.train_config.lr, self.train_config.lr + ), + 
lr=self.train_config.lr + ) + lr_scheduler = train_util.get_lr_scheduler( + self.train_config.lr_scheduler, + optimizer, + max_iterations=self.train_config.steps, + lr_min=self.train_config.lr / 100, # not sure why leco did this, but ill do it to + ) + criteria = torch.nn.MSELoss() + + if self.logging_config.verbose: + print("Prompts") + for settings in self.prompt_settings: + print(settings) + + # debug + # debug_util.check_requires_grad(network) + # debug_util.check_training_mode(network) + + cache = PromptEmbedsCache() + prompt_pairs: list[PromptEmbedsPair] = [] + + with torch.no_grad(): + for settings in self.prompt_settings: + self.print(settings) + for prompt in [ + settings.target, + settings.positive, + settings.neutral, + settings.unconditional, + ]: + if cache[prompt] == None: + cache[prompt] = train_util.encode_prompts( + tokenizer, text_encoder, [prompt] + ) + + prompt_pairs.append( + PromptEmbedsPair( + criteria, + cache[settings.target], + cache[settings.positive], + cache[settings.unconditional], + cache[settings.neutral], + settings, + ) + ) + + # move to cpu to save vram + # tokenizer.to("cpu") + text_encoder.to("cpu") + flush() + + # sample first + self.print("Generating baseline samples before training") + self.sample(0) + + self.progress_bar = tqdm(range(self.train_config.steps)) + self.progress_bar = tqdm( + total=self.train_config.steps, + desc=self.job.name, + leave=True + ) + self.step_num = 0 + for step in range(self.train_config.steps): + with torch.no_grad(): + noise_scheduler.set_timesteps( + self.train_config.max_denoising_steps, device=self.device_torch + ) + + optimizer.zero_grad() + + prompt_pair: PromptEmbedsPair = prompt_pairs[ + torch.randint(0, len(prompt_pairs), (1,)).item() + ] + + # 1 ~ 49 random from 1 to 49 + timesteps_to = torch.randint( + 1, self.train_config.max_denoising_steps, (1,) + ).item() + + height, width = ( + prompt_pair.resolution, + prompt_pair.resolution, + ) + if prompt_pair.dynamic_resolution: + height, width = train_util.get_random_resolution_in_bucket( + prompt_pair.resolution + ) + + if self.logging_config.verbose: + self.print("guidance_scale:", prompt_pair.guidance_scale) + self.print("resolution:", prompt_pair.resolution) + self.print("dynamic_resolution:", prompt_pair.dynamic_resolution) + if prompt_pair.dynamic_resolution: + self.print("bucketed resolution:", (height, width)) + self.print("batch_size:", prompt_pair.batch_size) + + latents = train_util.get_initial_latents( + noise_scheduler, prompt_pair.batch_size, height, width, 1 + ).to(self.device_torch, dtype=dtype) + + with self.network: + assert self.network.is_active + # A little denoised one is returned + denoised_latents = train_util.diffusion( + unet, + noise_scheduler, + latents, # pass simple noise latents + train_util.concat_embeddings( + prompt_pair.unconditional, + prompt_pair.target, + prompt_pair.batch_size, + ), + start_timesteps=0, + total_timesteps=timesteps_to, + guidance_scale=3, + ) + + noise_scheduler.set_timesteps(1000) + + current_timestep = noise_scheduler.timesteps[ + int(timesteps_to * 1000 / self.train_config.max_denoising_steps) + ] + + # with network: Only empty LoRA is enabled outside with network : + positive_latents = train_util.predict_noise( + unet, + noise_scheduler, + current_timestep, + denoised_latents, + train_util.concat_embeddings( + prompt_pair.unconditional, + prompt_pair.positive, + prompt_pair.batch_size, + ), + guidance_scale=1, + ).to("cpu", dtype=torch.float32) + neutral_latents = train_util.predict_noise( + unet, + 
noise_scheduler, + current_timestep, + denoised_latents, + train_util.concat_embeddings( + prompt_pair.unconditional, + prompt_pair.neutral, + prompt_pair.batch_size, + ), + guidance_scale=1, + ).to("cpu", dtype=torch.float32) + unconditional_latents = train_util.predict_noise( + unet, + noise_scheduler, + current_timestep, + denoised_latents, + train_util.concat_embeddings( + prompt_pair.unconditional, + prompt_pair.unconditional, + prompt_pair.batch_size, + ), + guidance_scale=1, + ).to("cpu", dtype=torch.float32) + # if self.logging_config.verbose: + # self.print("positive_latents:", positive_latents[0, 0, :5, :5]) + # self.print("neutral_latents:", neutral_latents[0, 0, :5, :5]) + # self.print("unconditional_latents:", unconditional_latents[0, 0, :5, :5]) + + with self.network: + target_latents = train_util.predict_noise( + unet, + noise_scheduler, + current_timestep, + denoised_latents, + train_util.concat_embeddings( + prompt_pair.unconditional, + prompt_pair.target, + prompt_pair.batch_size, + ), + guidance_scale=1, + ).to("cpu", dtype=torch.float32) + + # if self.logging_config.verbose: + # self.print("target_latents:", target_latents[0, 0, :5, :5]) + + positive_latents.requires_grad = False + neutral_latents.requires_grad = False + unconditional_latents.requires_grad = False + + loss = prompt_pair.loss( + target_latents=target_latents, + positive_latents=positive_latents, + neutral_latents=neutral_latents, + unconditional_latents=unconditional_latents, + ) + loss_float = loss.item() + if self.train_config.optimizer.startswith('dadaptation'): + learning_rate = ( + optimizer.param_groups[0]["d"] * + optimizer.param_groups[0]["lr"] + ) + else: + learning_rate = optimizer.param_groups[0]['lr'] + + self.progress_bar.set_postfix_str(f"lr: {learning_rate:.1e} loss: {loss.item():.3e}") + + loss.backward() + optimizer.step() + lr_scheduler.step() + + del ( + positive_latents, + neutral_latents, + unconditional_latents, + target_latents, + latents, + ) + flush() + + # don't do on first step + if self.step_num != self.start_step: + # pause progress bar + self.progress_bar.unpause() # makes it so doesn't track time + if self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0: + # print above the progress bar + self.sample(self.step_num) + + if self.save_config.save_every and self.step_num % self.save_config.save_every == 0: + # print above the progress bar + self.print(f"Saving at step {self.step_num}") + self.save(self.step_num) + + if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0: + # log to tensorboard + if self.writer is not None: + # get avg loss + self.writer.add_scalar(f"loss", loss_float, self.step_num) + if self.train_config.optimizer.startswith('dadaptation'): + learning_rate = ( + optimizer.param_groups[0]["d"] * + optimizer.param_groups[0]["lr"] + ) + else: + learning_rate = optimizer.param_groups[0]['lr'] + self.writer.add_scalar(f"lr", learning_rate, self.step_num) + self.progress_bar.refresh() + + # sets progress bar to match out step + self.progress_bar.update(step - self.progress_bar.n) + # end of step + self.step_num = step + + self.save() + + del ( + unet, + noise_scheduler, + loss, + optimizer, + self.network, + tokenizer, + text_encoder, + ) + + flush() diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py index f2e0a518..1ec467f1 100644 --- a/jobs/process/TrainVAEProcess.py +++ b/jobs/process/TrainVAEProcess.py @@ -178,7 +178,6 @@ class TrainVAEProcess(BaseTrainProcess): 
self.device = self.get_conf('device', self.job.device) self.vae_path = self.get_conf('vae_path', required=True) self.datasets_objects = self.get_conf('datasets', required=True) - self.training_folder = self.get_conf('training_folder', self.job.training_folder) self.batch_size = self.get_conf('batch_size', 1, as_type=int) self.resolution = self.get_conf('resolution', 256, as_type=int) self.learning_rate = self.get_conf('learning_rate', 1e-6, as_type=float) @@ -197,14 +196,10 @@ class TrainVAEProcess(BaseTrainProcess): self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float) self.critic_weight = self.get_conf('critic_weight', 1, as_type=float) self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float) - self.first_step = 0 self.blocks_to_train = self.get_conf('blocks_to_train', ['all']) - self.writer = self.job.writer self.torch_dtype = get_torch_dtype(self.dtype) - self.save_root = os.path.join(self.training_folder, self.job.name) self.vgg_19 = None - self.progress_bar = None self.style_weight_scalers = [] self.content_weight_scalers = [] @@ -254,13 +249,6 @@ class TrainVAEProcess(BaseTrainProcess): }) return info - def print(self, message, **kwargs): - if self.progress_bar is not None: - self.progress_bar.write(message, **kwargs) - self.progress_bar.update() - else: - print(message, **kwargs) - def load_datasets(self): if self.data_loader is None: print(f"Loading datasets") diff --git a/jobs/process/__init__.py b/jobs/process/__init__.py index 413aebce..357acc8d 100644 --- a/jobs/process/__init__.py +++ b/jobs/process/__init__.py @@ -5,3 +5,4 @@ from .BaseProcess import BaseProcess from .BaseTrainProcess import BaseTrainProcess from .TrainVAEProcess import TrainVAEProcess from .BaseMergeProcess import BaseMergeProcess +from .TrainSliderProcess import TrainSliderProcess diff --git a/requirements.txt b/requirements.txt index a8b4231f..ee597c3d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,5 @@ flatten_json accelerator pyyaml oyaml -tensorboard \ No newline at end of file +tensorboard +kornia \ No newline at end of file diff --git a/toolkit/lora.py b/toolkit/lora.py new file mode 100644 index 00000000..9b3b65a6 --- /dev/null +++ b/toolkit/lora.py @@ -0,0 +1,238 @@ +# ref: +# - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py +# - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py +# - https://github.com/p1atdev/LECO/blob/main/lora.py + +import os +import math +from typing import Optional, List, Type, Set, Literal + +import torch +import torch.nn as nn +from diffusers import UNet2DConditionModel +from safetensors.torch import save_file + + +UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [ + "Transformer2DModel", # どうやらこっちの方らしい? # attn1, 2 +] +UNET_TARGET_REPLACE_MODULE_CONV = [ + "ResnetBlock2D", + "Downsample2D", + "Upsample2D", +] # locon, 3clier + +LORA_PREFIX_UNET = "lora_unet" + +DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER + +TRAINING_METHODS = Literal[ + "noxattn", # train all layers except x-attns and time_embed layers + "innoxattn", # train all layers except self attention layers + "selfattn", # ESD-u, train only self attention layers + "xattn", # ESD-x, train only x attention layers + "full", # train all layers + # "notime", + # "xlayer", + # "outxattn", + # "outsattn", + # "inxattn", + # "inmidsattn", + # "selflayer", +] + + +class LoRAModule(nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. 
+ """ + + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__() + self.lora_name = lora_name + self.lora_dim = lora_dim + + if org_module.__class__.__name__ == "Linear": + in_dim = org_module.in_features + out_dim = org_module.out_features + self.lora_down = nn.Linear(in_dim, lora_dim, bias=False) + self.lora_up = nn.Linear(lora_dim, out_dim, bias=False) + + elif org_module.__class__.__name__ == "Conv2d": # 一応 + in_dim = org_module.in_channels + out_dim = org_module.out_channels + + self.lora_dim = min(self.lora_dim, in_dim, out_dim) + if self.lora_dim != lora_dim: + print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = nn.Conv2d( + in_dim, self.lora_dim, kernel_size, stride, padding, bias=False + ) + self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().numpy() + alpha = lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + # same as microsoft's + nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_up.weight) + + self.multiplier = multiplier + self.org_module = org_module # remove in applying + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def forward(self, x): + return ( + self.org_forward(x) + + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + ) + + +class LoRANetwork(nn.Module): + def __init__( + self, + unet: UNet2DConditionModel, + rank: int = 4, + multiplier: float = 1.0, + alpha: float = 1.0, + train_method: TRAINING_METHODS = "full", + ) -> None: + super().__init__() + + self.multiplier = multiplier + self.lora_dim = rank + self.alpha = alpha + + # LoRAのみ + self.module = LoRAModule + + # unetのloraを作る + self.unet_loras = self.create_modules( + LORA_PREFIX_UNET, + unet, + DEFAULT_TARGET_REPLACE, + self.lora_dim, + self.multiplier, + train_method=train_method, + ) + print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + + # assertion 名前の被りがないか確認しているようだ + lora_names = set() + for lora in self.unet_loras: + assert ( + lora.lora_name not in lora_names + ), f"duplicated lora name: {lora.lora_name}. {lora_names}" + lora_names.add(lora.lora_name) + + # 適用する + for lora in self.unet_loras: + lora.apply_to() + self.add_module( + lora.lora_name, + lora, + ) + + del unet + + torch.cuda.empty_cache() + + def create_modules( + self, + prefix: str, + root_module: nn.Module, + target_replace_modules: List[str], + rank: int, + multiplier: float, + train_method: TRAINING_METHODS, + ) -> list: + loras = [] + + for name, module in root_module.named_modules(): + if train_method == "noxattn": # Cross Attention と Time Embed 以外学習 + if "attn2" in name or "time_embed" in name: + continue + elif train_method == "innoxattn": # Cross Attention 以外学習 + if "attn2" in name: + continue + elif train_method == "selfattn": # Self Attention のみ学習 + if "attn1" not in name: + continue + elif train_method == "xattn": # Cross Attention のみ学習 + if "attn2" not in name: + continue + elif train_method == "full": # 全部学習 + pass + else: + raise NotImplementedError( + f"train_method: {train_method} is not implemented." 
+ ) + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + if child_module.__class__.__name__ in ["Linear", "Conv2d"]: + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + print(f"{lora_name}") + lora = self.module( + lora_name, child_module, multiplier, rank, self.alpha + ) + loras.append(lora) + + return loras + + def prepare_optimizer_params(self): + all_params = [] + + if self.unet_loras: # 実質これしかない + params = [] + [params.extend(lora.parameters()) for lora in self.unet_loras] + param_data = {"params": params} + all_params.append(param_data) + + return all_params + + def save_weights(self, file, dtype=None, metadata: Optional[dict] = None): + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + for key in list(state_dict.keys()): + if not key.startswith("lora"): + # lora以外除外 + del state_dict[key] + + if os.path.splitext(file)[1] == ".safetensors": + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def __enter__(self): + for lora in self.unet_loras: + lora.multiplier = 1.0 + + def __exit__(self, exc_type, exc_value, tb): + for lora in self.unet_loras: + lora.multiplier = 0 diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py new file mode 100644 index 00000000..453de5ba --- /dev/null +++ b/toolkit/lora_special.py @@ -0,0 +1,226 @@ +import os +import sys +from typing import List + +import torch +from .paths import SD_SCRIPTS_ROOT + +sys.path.append(SD_SCRIPTS_ROOT) + +from networks.lora import LoRANetwork, LoRAModule, get_block_index + + +class LoRASpecialNetwork(LoRANetwork): + _multiplier: float = 1.0 + is_active: bool = False + + def __init__( + self, + text_encoder, + unet, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + conv_lora_dim=None, + conv_alpha=None, + block_dims=None, + block_alphas=None, + conv_block_dims=None, + conv_block_alphas=None, + modules_dim=None, + modules_alpha=None, + module_class=LoRAModule, + varbose=False, + ) -> None: + """ + LoRA network: すごく引数が多いが、パターンは以下の通り + 1. lora_dimとalphaを指定 + 2. lora_dim、alpha、conv_lora_dim、conv_alphaを指定 + 3. block_dimsとblock_alphasを指定 : Conv2d3x3には適用しない + 4. block_dims、block_alphas、conv_block_dims、conv_block_alphasを指定 : Conv2d3x3にも適用する + 5. modules_dimとmodules_alphaを指定 (推論用) + """ + # call the parent of the parent we are replacing (LoRANetwork) init + super(LoRANetwork, self).__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + if modules_dim is not None: + print(f"create LoRA network from weights") + elif block_dims is not None: + print(f"create LoRA network from block_dims") + print( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + print(f"block_dims: {block_dims}") + print(f"block_alphas: {block_alphas}") + if conv_block_dims is not None: + print(f"conv_block_dims: {conv_block_dims}") + print(f"conv_block_alphas: {conv_block_alphas}") + else: + print(f"create LoRA network. 
base dim (rank): {lora_dim}, alpha: {alpha}") + print( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + if self.conv_lora_dim is not None: + print( + f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + + # create module instances + def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]: + prefix = LoRANetwork.LORA_PREFIX_UNET if is_unet else LoRANetwork.LORA_PREFIX_TEXT_ENCODER + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + + dim = None + alpha = None + if modules_dim is not None: + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + elif is_unet and block_dims is not None: + block_idx = get_block_index(lora_name) + if is_linear or is_conv2d_1x1: + dim = block_dims[block_idx] + alpha = block_alphas[block_idx] + elif conv_block_dims is not None: + dim = conv_block_dims[block_idx] + alpha = conv_block_alphas[block_idx] + else: + if is_linear or is_conv2d_1x1: + dim = self.lora_dim + alpha = self.alpha + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + + if dim is None or dim == 0: + if is_linear or is_conv2d_1x1 or ( + self.conv_lora_dim is not None or conv_block_dims is not None): + skipped.append(lora_name) + continue + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + ) + loras.append(lora) + return loras, skipped + + self.text_encoder_loras, skipped_te = create_modules(False, text_encoder, + LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + + # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights + target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None: + target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + + self.unet_loras, skipped_un = create_modules(True, unet, target_modules) + print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + + skipped = skipped_te + skipped_un + if varbose and len(skipped) > 0: + print( + f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + print(f"\t{name}") + + self.up_lr_weight: List[float] = None + self.down_lr_weight: List[float] = None + self.mid_lr_weight: float = None + self.block_lr = False + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + # doesnt work on new diffusers. 
TODO make sure we are not missing something + # assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + @property + def multiplier(self): + return self._multiplier + + @multiplier.setter + def multiplier(self, value): + self._multiplier = value + self._update_lora_multiplier() + + def _update_lora_multiplier(self): + + if self.is_active: + if hasattr(self, 'unet_loras'): + for lora in self.unet_loras: + lora.multiplier = self._multiplier + if hasattr(self, 'text_encoder_loras'): + for lora in self.text_encoder_loras: + lora.multiplier = self._multiplier + else: + if hasattr(self, 'unet_loras'): + for lora in self.unet_loras: + lora.multiplier = 0 + if hasattr(self, 'text_encoder_loras'): + for lora in self.text_encoder_loras: + lora.multiplier = 0 + + def __enter__(self): + self.is_active = True + self._update_lora_multiplier() + + def __exit__(self, exc_type, exc_value, tb): + self.is_active = False + self._update_lora_multiplier() + + def force_to(self, device, dtype): + self.to(device, dtype) + loras = [] + if hasattr(self, 'unet_loras'): + loras += self.unet_loras + if hasattr(self, 'text_encoder_loras'): + loras += self.text_encoder_loras + for lora in loras: + lora.to(device, dtype) diff --git a/toolkit/losses.py b/toolkit/losses.py index f8c2855a..aded0764 100644 --- a/toolkit/losses.py +++ b/toolkit/losses.py @@ -83,3 +83,5 @@ class PatternLoss(torch.nn.Module): g_chan_loss = torch.abs(separated_chan_loss(g_chans) - separated_chan_loss(g_chans_target)) b_chan_loss = torch.abs(separated_chan_loss(b_chans) - separated_chan_loss(b_chans_target)) return (r_chan_loss + g_chan_loss + b_chan_loss) * 0.3333 + +
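A minimal sketch (not part of the patch) of the LoRA gating pattern that the new toolkit/lora.py, toolkit/lora_special.py, and TrainSliderProcess.flip_network rely on: the original forward is kept, the low-rank delta is scaled by multiplier * (alpha / rank), the network is a context manager that sets the multiplier to 1.0 on entry and 0 on exit, and the slider direction is flipped by negating LoRA weights. The class names ToyLoRALinear and ToyLoRANetwork are illustrative only and do not exist in the repo; in the real code, LoRAModule.apply_to() also swaps the wrapped module's forward and drops the reference so only LoRA weights are ever negated.

import math
import torch
import torch.nn as nn


class ToyLoRALinear(nn.Module):
    """Wraps a Linear the way LoRAModule does: keep the original forward and add
    a low-rank delta scaled by multiplier * (alpha / rank)."""

    def __init__(self, org_module: nn.Linear, rank: int = 4, alpha: float = 1.0):
        super().__init__()
        self.org_module = org_module
        self.lora_down = nn.Linear(org_module.in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, org_module.out_features, bias=False)
        self.scale = alpha / rank
        self.multiplier = 0.0  # 0 means "LoRA off", matching __exit__ in toolkit/lora.py
        # same init scheme as the patch: kaiming down, zeros up
        nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, x):
        return self.org_module(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale


class ToyLoRANetwork(nn.Module):
    """Context manager mirroring LoRANetwork / LoRASpecialNetwork: entering the
    `with` block activates every module, leaving it zeroes the multiplier."""

    def __init__(self, modules, multiplier: float = 1.0):
        super().__init__()
        self.loras = nn.ModuleList(modules)
        self.multiplier = multiplier
        self.is_active = False

    def __enter__(self):
        self.is_active = True
        for lora in self.loras:
            lora.multiplier = self.multiplier

    def __exit__(self, exc_type, exc_value, tb):
        self.is_active = False
        for lora in self.loras:
            lora.multiplier = 0.0

    def flip(self):
        # Same idea as TrainSliderProcess.flip_network (which negates every LoRA
        # parameter); negating lora_up alone is enough to flip the delta's sign here.
        for lora in self.loras:
            lora.lora_up.weight.data = -lora.lora_up.weight.data


if __name__ == "__main__":
    base = nn.Linear(8, 8)
    lora = ToyLoRALinear(base, rank=4, alpha=1.0)
    net = ToyLoRANetwork([lora])
    x = torch.randn(1, 8)
    # Outside the `with` block the delta is multiplied by 0, so the wrapped
    # layer behaves exactly like the original one.
    assert torch.allclose(lora(x), base(x))
    with net:
        _ = lora(x)  # LoRA path is active here (delta stays zero until lora_up is trained)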
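For reference, a hypothetical process config shaped after the get_conf() calls and the SaveConfig / SampleConfig / NetworkConfig / TrainConfig / ModelConfig / PromptSettingsOld wrappers added in this patch. Key names come from the patch; all values are illustrative, the model path is only an example, and how the job dispatches to TrainSliderProcess is not shown in this diff. The '--n' marker in a sample prompt is split into a negative prompt, as in TrainSliderProcess.sample().

from collections import OrderedDict

slider_process_config = OrderedDict({
    'device': 'cuda:0',
    'training_folder': 'output',
    'network': {'type': 'lierla', 'rank': 4, 'alpha': 1.0},
    'train': {
        'noise_scheduler': 'ddpm',
        'steps': 1000,
        'lr': 1e-4,
        'optimizer': 'adamw',
        'lr_scheduler': 'constant',
        'max_denoising_steps': 50,
        'batch_size': 1,
        'dtype': 'fp32',
        'xformers': False,
        'train_unet': True,
        'train_text_encoder': True,
    },
    # name_or_path is the one key ModelConfig requires; the value here is an example
    'model': {'name_or_path': 'runwayml/stable-diffusion-v1-5', 'is_v2': False, 'is_v_pred': False},
    'save': {'save_every': 1000, 'save_dtype': 'float16'},
    'sample': {
        'sample_every': 100,
        'width': 512,
        'height': 512,
        'prompts': ['photo of a person --n blurry, low quality'],
        'seed': 0,
        'walk_seed': False,
        'guidance_scale': 7,
        'sample_steps': 20,
    },
    'logging': {'log_every': 100, 'verbose': False, 'use_wandb': False},
    'prompts': [
        {
            'target': 'person',
            'positive': 'person, smiling',
            'unconditional': '',
            'neutral': 'person',
            'action': 'erase',
            'guidance_scale': 1.0,
            'resolution': 512,
            'batch_size': 1,
        },
    ],
})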