diff --git a/jobs/BaseJob.py b/jobs/BaseJob.py index 3d0129df..f29a29d1 100644 --- a/jobs/BaseJob.py +++ b/jobs/BaseJob.py @@ -17,6 +17,7 @@ class BaseJob: raise ValueError('config is required') self.config = config['config'] + self.raw_config = config self.job = config['job'] self.name = self.get_conf('name', required=True) if 'meta' in config: diff --git a/jobs/TrainJob.py b/jobs/TrainJob.py index c8b4e4b9..d988f24e 100644 --- a/jobs/TrainJob.py +++ b/jobs/TrainJob.py @@ -1,3 +1,4 @@ +import json import os from jobs import BaseJob @@ -6,7 +7,7 @@ from collections import OrderedDict from typing import List from jobs.process import BaseExtractProcess, TrainFineTuneProcess from datetime import datetime - +import yaml from toolkit.paths import REPOS_ROOT import sys @@ -16,6 +17,7 @@ sys.path.append(REPOS_ROOT) process_dict = { 'vae': 'TrainVAEProcess', 'slider': 'TrainSliderProcess', + 'lora_hack': 'TrainLoRAHack', } @@ -37,6 +39,13 @@ class TrainJob(BaseJob): # loads the processes from the config self.load_processes(process_dict) + def save_training_config(self): + timestamp = datetime.now().strftime('%Y%m%d-%H%M%S') + os.makedirs(self.training_folder, exist_ok=True) + save_dif = os.path.join(self.training_folder, f'run_config_{timestamp}.yaml') + with open(save_dif, 'w') as f: + yaml.dump(self.raw_config, f) + def run(self): super().run() print("") diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 79774f8f..5e9f9405 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -2,7 +2,6 @@ import time from collections import OrderedDict import os -from leco.train_util import predict_noise from toolkit.kohya_model_util import load_vae from toolkit.lora_special import LoRASpecialNetwork from toolkit.optimizer import get_optimizer @@ -12,7 +11,7 @@ import sys sys.path.append(REPOS_ROOT) sys.path.append(os.path.join(REPOS_ROOT, 'leco')) -from diffusers import StableDiffusionPipeline +from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline from jobs.process import BaseTrainProcess from toolkit.metadata import get_meta_for_safetensors @@ -24,6 +23,7 @@ from tqdm import tqdm from leco import train_util, model_util from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig +from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds def flush(): @@ -35,15 +35,6 @@ UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8 -class StableDiffusion: - def __init__(self, vae, tokenizer, text_encoder, unet, noise_scheduler): - self.vae = vae - self.tokenizer = tokenizer - self.text_encoder = text_encoder - self.unet = unet - self.noise_scheduler = noise_scheduler - - class BaseSDTrainProcess(BaseTrainProcess): def __init__(self, process_id: int, job, config: OrderedDict): super().__init__(process_id, job, config) @@ -80,26 +71,44 @@ class BaseSDTrainProcess(BaseTrainProcess): original_device_dict = { 'vae': self.sd.vae.device, 'unet': self.sd.unet.device, - 'text_encoder': self.sd.text_encoder.device, # 'tokenizer': self.sd.tokenizer.device, } + # handle sdxl text encoder + if isinstance(self.sd.text_encoder, list): + for encoder, i in zip(self.sd.text_encoder, range(len(self.sd.text_encoder))): + original_device_dict[f'text_encoder_{i}'] = encoder.device + encoder.to(self.device_torch) + else: + original_device_dict['text_encoder'] = self.sd.text_encoder.device + 
self.sd.text_encoder.to(self.device_torch) + self.sd.vae.to(self.device_torch) self.sd.unet.to(self.device_torch) - self.sd.text_encoder.to(self.device_torch) + # self.sd.text_encoder.to(self.device_torch) # self.sd.tokenizer.to(self.device_torch) # TODO add clip skip - - pipeline = StableDiffusionPipeline( - vae=self.sd.vae, - unet=self.sd.unet, - text_encoder=self.sd.text_encoder, - tokenizer=self.sd.tokenizer, - scheduler=self.sd.noise_scheduler, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=False, - ) + if self.sd.is_xl: + pipeline = StableDiffusionXLPipeline( + vae=self.sd.vae, + unet=self.sd.unet, + text_encoder=self.sd.text_encoder[0], + text_encoder_2=self.sd.text_encoder[1], + tokenizer=self.sd.tokenizer[0], + tokenizer_2=self.sd.tokenizer[1], + scheduler=self.sd.noise_scheduler, + ) + else: + pipeline = StableDiffusionPipeline( + vae=self.sd.vae, + unet=self.sd.unet, + text_encoder=self.sd.text_encoder, + tokenizer=self.sd.tokenizer, + scheduler=self.sd.noise_scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) # disable progress bar pipeline.set_progress_bar_config(disable=True) @@ -118,7 +127,8 @@ class BaseSDTrainProcess(BaseTrainProcess): 'multiplier': self.network.multiplier, }) - for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}", leave=False): + for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}", + leave=False): raw_prompt = self.sample_config.prompts[i] neg = self.sample_config.neg @@ -180,7 +190,11 @@ class BaseSDTrainProcess(BaseTrainProcess): self.sd.vae.to(original_device_dict['vae']) self.sd.unet.to(original_device_dict['unet']) - self.sd.text_encoder.to(original_device_dict['text_encoder']) + if isinstance(self.sd.text_encoder, list): + for encoder, i in zip(self.sd.text_encoder, range(len(self.sd.text_encoder))): + encoder.to(original_device_dict[f'text_encoder_{i}']) + else: + self.sd.text_encoder.to(original_device_dict['text_encoder']) if self.network is not None: self.network.train() self.network.multiplier = start_multiplier @@ -267,23 +281,90 @@ class BaseSDTrainProcess(BaseTrainProcess): # return loss return 0.0 - # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746 - def diffuse_some_steps( + def get_time_ids_from_latents(self, latents): + bs, ch, h, w = list(latents.shape) + + height = h * VAE_SCALE_FACTOR + width = w * VAE_SCALE_FACTOR + + dtype = get_torch_dtype(self.train_config.dtype) + + if self.sd.is_xl: + prompt_ids = train_util.get_add_time_ids( + height, + width, + dynamic_crops=False, # look into this + dtype=dtype, + ).to(self.device_torch, dtype=dtype) + return train_util.concat_embeddings( + prompt_ids, prompt_ids, bs + ) + else: + return None + + def predict_noise( self, latents: torch.FloatTensor, - text_embeddings: torch.FloatTensor, - total_timesteps: int = 1000, - start_timesteps=0, + text_embeddings: PromptEmbeds, + timestep: int, + guidance_scale=7.5, + guidance_rescale=0.7, + add_time_ids=None, **kwargs, ): - - for timestep in tqdm(self.sd.noise_scheduler.timesteps[start_timesteps:total_timesteps], leave=False): - noise_pred = train_util.predict_noise( - self.sd.unet, self.sd.noise_scheduler, timestep, latents, text_embeddings, **kwargs + if self.sd.is_xl: + if add_time_ids is None: + add_time_ids = self.get_time_ids_from_latents(latents) + # todo LECOs code 
looks like it is omitting noise_pred + noise_pred = train_util.predict_noise_xl( + self.sd.unet, + self.sd.noise_scheduler, + timestep, + latents, + text_embeddings.text_embeds, + text_embeddings.pooled_embeds, + add_time_ids, + guidance_scale=guidance_scale, + guidance_rescale=guidance_rescale ) # compute the previous noisy sample x_t -> x_t-1 latents = self.sd.noise_scheduler.step(noise_pred, timestep, latents).prev_sample + else: + noise_pred = train_util.predict_noise( + self.sd.unet, + self.sd.noise_scheduler, + timestep, + latents, + text_embeddings.text_embeds, + guidance_scale=guidance_scale + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.sd.noise_scheduler.step(noise_pred, timestep, latents).prev_sample + return latents + + # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746 + def diffuse_some_steps( + self, + latents: torch.FloatTensor, + text_embeddings: PromptEmbeds, + total_timesteps: int = 1000, + start_timesteps=0, + guidance_scale=1, + add_time_ids=None, + **kwargs, + ): + + for timestep in tqdm(self.sd.noise_scheduler.timesteps[start_timesteps:total_timesteps], leave=False): + latents = self.predict_noise( + latents, + text_embeddings, + timestep, + guidance_scale=guidance_scale, + add_time_ids=add_time_ids, + **kwargs, + ) # return latents_steps return latents @@ -296,20 +377,35 @@ class BaseSDTrainProcess(BaseTrainProcess): dtype = get_torch_dtype(self.train_config.dtype) - tokenizer, text_encoder, unet, noise_scheduler = model_util.load_models( - self.model_config.name_or_path, - scheduler_name=self.train_config.noise_scheduler, - v2=self.model_config.is_v2, - v_pred=self.model_config.is_v_pred, - ) + if self.model_config.is_xl: + tokenizer, text_encoders, unet, noise_scheduler = model_util.load_models_xl( + self.model_config.name_or_path, + scheduler_name=self.train_config.noise_scheduler, + weight_dtype=dtype, + ) + + for text_encoder in text_encoders: + text_encoder.to(self.device_torch, dtype=dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + + text_encoder = text_encoders + else: + tokenizer, text_encoder, unet, noise_scheduler = model_util.load_models( + self.model_config.name_or_path, + scheduler_name=self.train_config.noise_scheduler, + v2=self.model_config.is_v2, + v_pred=self.model_config.is_v_pred, + ) + + text_encoder.to(self.device_torch, dtype=dtype) + text_encoder.eval() + # just for now or of we want to load a custom one # put on cpu for now, we only need it when sampling vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype) vae.eval() - self.sd = StableDiffusion(vae, tokenizer, text_encoder, unet, noise_scheduler) - - text_encoder.to(self.device_torch, dtype=dtype) - text_encoder.eval() + self.sd = StableDiffusion(vae, tokenizer, text_encoder, unet, noise_scheduler, is_xl=self.model_config.is_xl) unet.to(self.device_torch, dtype=dtype) if self.train_config.xformers: @@ -323,7 +419,9 @@ class BaseSDTrainProcess(BaseTrainProcess): unet=unet, lora_dim=self.network_config.rank, multiplier=1.0, - alpha=self.network_config.alpha + alpha=self.network_config.alpha, + train_unet=self.train_config.train_unet, + train_text_encoder=self.train_config.train_text_encoder, ) self.network.force_to(self.device_torch, dtype=dtype) @@ -376,8 +474,11 @@ class BaseSDTrainProcess(BaseTrainProcess): self.hook_before_train_loop() # sample first - self.print("Generating baseline samples before 
training") - self.sample(0) + if self.train_config.skip_first_sample: + self.print("Skipping first sample due to config setting") + else: + self.print("Generating baseline samples before training") + self.sample(0) self.progress_bar = tqdm( total=self.train_config.steps, diff --git a/jobs/process/TrainLoRAHack.py b/jobs/process/TrainLoRAHack.py new file mode 100644 index 00000000..a3fb118d --- /dev/null +++ b/jobs/process/TrainLoRAHack.py @@ -0,0 +1,76 @@ +# ref: +# - https://github.com/p1atdev/LECO/blob/main/train_lora.py +import time +from collections import OrderedDict +import os + +from toolkit.config_modules import SliderConfig +from toolkit.paths import REPOS_ROOT +import sys + +sys.path.append(REPOS_ROOT) +sys.path.append(os.path.join(REPOS_ROOT, 'leco')) +from toolkit.train_tools import get_torch_dtype, apply_noise_offset +import gc + +import torch +from leco import train_util, model_util +from leco.prompt_util import PromptEmbedsCache +from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class LoRAHack: + def __init__(self, **kwargs): + self.type = kwargs.get('type', 'suppression') + + +class TrainLoRAHack(BaseSDTrainProcess): + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + self.hack_config = LoRAHack(**self.get_conf('hack', {})) + + def hook_before_train_loop(self): + # we don't need text encoder so move it to cpu + self.sd.text_encoder.to("cpu") + flush() + # end hook_before_train_loop + + if self.hack_config.type == 'suppression': + # set all params to self.current_suppression + params = self.network.parameters() + for param in params: + # get random noise for each param + noise = torch.randn_like(param) - 0.5 + # apply noise to param + param.data = noise * 0.001 + + + def supress_loop(self): + dtype = get_torch_dtype(self.train_config.dtype) + + + loss_dict = OrderedDict( + {'sup': 0.0} + ) + # increase noise + for param in self.network.parameters(): + # get random noise for each param + noise = torch.randn_like(param) - 0.5 + # apply noise to param + param.data = param.data + noise * 0.001 + + + + return loss_dict + + def hook_train_loop(self): + if self.hack_config.type == 'suppression': + return self.supress_loop() + else: + raise NotImplementedError(f'unknown hack type: {self.hack_config.type}') + # end hook_train_loop diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py index 2bed4b29..fddcecb1 100644 --- a/jobs/process/TrainSliderProcess.py +++ b/jobs/process/TrainSliderProcess.py @@ -3,19 +3,22 @@ import time from collections import OrderedDict import os +from typing import Optional from toolkit.config_modules import SliderConfig from toolkit.paths import REPOS_ROOT import sys +from toolkit.stable_diffusion_model import PromptEmbeds + sys.path.append(REPOS_ROOT) sys.path.append(os.path.join(REPOS_ROOT, 'leco')) from toolkit.train_tools import get_torch_dtype, apply_noise_offset import gc +from toolkit import train_tools import torch from leco import train_util, model_util -from leco.prompt_util import PromptEmbedsCache from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion @@ -29,7 +32,6 @@ def flush(): gc.collect() - class EncodedPromptPair: def __init__( self, @@ -54,6 +56,19 @@ class EncodedPromptPair: self.weight = weight +class PromptEmbedsCache: # 使いまわしたいので + prompts: dict[str, PromptEmbeds] = {} + + def __setitem__(self, __name: str, __value: PromptEmbeds) -> None: + 
self.prompts[__name] = __value + + def __getitem__(self, __name: str) -> Optional[PromptEmbeds]: + if __name in self.prompts: + return self.prompts[__name] + else: + return None + + class EncodedAnchor: def __init__( self, @@ -89,19 +104,17 @@ class TrainSliderProcess(BaseSDTrainProcess): with torch.no_grad(): neutral = "" for target in self.slider_config.targets: + # build the cache + for prompt in [ + target.target_class, + target.positive, + target.negative, + neutral # empty neutral + ]: + if cache[prompt] is None: + cache[prompt] = self.sd.encode_prompt(prompt) for resolution in self.slider_config.resolutions: width, height = resolution - # build the cache - for prompt in [ - target.target_class, - target.positive, - target.negative, - neutral # empty neutral - ]: - if cache[prompt] == None: - cache[prompt] = train_util.encode_prompts( - self.sd.tokenizer, self.sd.text_encoder, [prompt] - ) only_erase = len(target.positive.strip()) == 0 only_enhance = len(target.negative.strip()) == 0 @@ -184,9 +197,7 @@ class TrainSliderProcess(BaseSDTrainProcess): anchor.neg_prompt # empty neutral ]: if cache[prompt] == None: - cache[prompt] = train_util.encode_prompts( - self.sd.tokenizer, self.sd.text_encoder, [prompt] - ) + cache[prompt] = self.sd.encode_prompt(prompt) anchor_pairs += [ EncodedAnchor( @@ -198,7 +209,12 @@ class TrainSliderProcess(BaseSDTrainProcess): # move to cpu to save vram # We don't need text encoder anymore, but keep it on cpu for sampling - self.sd.text_encoder.to("cpu") + # if text encoder is list + if isinstance(self.sd.text_encoder, list): + for encoder in self.sd.text_encoder: + encoder.to("cpu") + else: + self.sd.text_encoder.to("cpu") self.prompt_cache = cache self.prompt_pairs = prompt_pairs self.anchor_pairs = anchor_pairs @@ -220,6 +236,7 @@ class TrainSliderProcess(BaseSDTrainProcess): negative = prompt_pair.negative positive = prompt_pair.positive weight = prompt_pair.weight + multiplier = prompt_pair.multiplier unet = self.sd.unet noise_scheduler = self.sd.noise_scheduler @@ -227,8 +244,20 @@ class TrainSliderProcess(BaseSDTrainProcess): lr_scheduler = self.lr_scheduler loss_function = torch.nn.MSELoss() + def get_noise_pred(p, n): + return self.predict_noise( + latents=denoised_latents, + text_embeddings=train_tools.concat_prompt_embeddings( + p, # unconditional + n, # positive + self.train_config.batch_size, + ), + timestep=current_timestep, + guidance_scale=1, + ) + # set network multiplier - self.network.multiplier = prompt_pair.multiplier + self.network.multiplier = multiplier with torch.no_grad(): self.sd.noise_scheduler.set_timesteps( @@ -254,9 +283,10 @@ class TrainSliderProcess(BaseSDTrainProcess): with self.network: assert self.network.is_active + self.network.multiplier = multiplier denoised_latents = self.diffuse_some_steps( latents, # pass simple noise latents - train_util.concat_embeddings( + train_tools.concat_prompt_embeddings( positive, # unconditional target_class, # target self.train_config.batch_size, @@ -272,43 +302,11 @@ class TrainSliderProcess(BaseSDTrainProcess): int(timesteps_to * 1000 / self.train_config.max_denoising_steps) ] - # with network: 0 weight LoRA is enabled outside "with network:" - positive_latents = train_util.predict_noise( # positive_latents - unet, - noise_scheduler, - current_timestep, - denoised_latents, - train_util.concat_embeddings( - positive, # unconditional - negative, # positive - self.train_config.batch_size, - ), - guidance_scale=1, - ).to("cpu", dtype=torch.float32) - neutral_latents = 
train_util.predict_noise( - unet, - noise_scheduler, - current_timestep, - denoised_latents, - train_util.concat_embeddings( - positive, # unconditional - neutral, # neutral - self.train_config.batch_size, - ), - guidance_scale=1, - ).to("cpu", dtype=torch.float32) - unconditional_latents = train_util.predict_noise( - unet, - noise_scheduler, - current_timestep, - denoised_latents, - train_util.concat_embeddings( - positive, # unconditional - positive, # unconditional - self.train_config.batch_size, - ), - guidance_scale=1, - ).to("cpu", dtype=torch.float32) + positive_latents = get_noise_pred(positive, negative) + + neutral_latents = get_noise_pred(positive, neutral) + + unconditional_latents = get_noise_pred(positive, positive) anchor_loss = None if len(self.anchor_pairs) > 0: @@ -317,51 +315,19 @@ class TrainSliderProcess(BaseSDTrainProcess): torch.randint(0, len(self.anchor_pairs), (1,)).item() ] with torch.no_grad(): - anchor_target_noise = train_util.predict_noise( - unet, - noise_scheduler, - current_timestep, - denoised_latents, - train_util.concat_embeddings( - anchor.prompt, - anchor.neg_prompt, - self.train_config.batch_size, - ), - guidance_scale=1, - ).to("cpu", dtype=torch.float32) + anchor_target_noise = get_noise_pred(anchor.prompt, anchor.neg_prompt) with self.network: # anchor whatever weight prompt pair is using pos_nem_mult = 1.0 if prompt_pair.multiplier > 0 else -1.0 self.network.multiplier = anchor.multiplier * pos_nem_mult - anchor_pred_noise = train_util.predict_noise( - unet, - noise_scheduler, - current_timestep, - denoised_latents, - train_util.concat_embeddings( - anchor.prompt, - anchor.neg_prompt, - self.train_config.batch_size, - ), - guidance_scale=1, - ).to("cpu", dtype=torch.float32) + + anchor_pred_noise = get_noise_pred(anchor.prompt, anchor.neg_prompt) self.network.multiplier = prompt_pair.multiplier with self.network: self.network.multiplier = prompt_pair.multiplier - target_latents = train_util.predict_noise( - unet, - noise_scheduler, - current_timestep, - denoised_latents, - train_util.concat_embeddings( - positive, # unconditional - target_class, # target - self.train_config.batch_size, - ), - guidance_scale=1, - ).to("cpu", dtype=torch.float32) + target_latents = get_noise_pred(positive, target_class) # if self.logging_config.verbose: # self.print("target_latents:", target_latents[0, 0, :5, :5]) diff --git a/jobs/process/__init__.py b/jobs/process/__init__.py index 357acc8d..00a2fbbe 100644 --- a/jobs/process/__init__.py +++ b/jobs/process/__init__.py @@ -6,3 +6,4 @@ from .BaseTrainProcess import BaseTrainProcess from .TrainVAEProcess import TrainVAEProcess from .BaseMergeProcess import BaseMergeProcess from .TrainSliderProcess import TrainSliderProcess +from .TrainLoRAHack import TrainLoRAHack diff --git a/requirements.txt b/requirements.txt index ee597c3d..e099b7d5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,5 @@ accelerator pyyaml oyaml tensorboard -kornia \ No newline at end of file +kornia +invisible-watermark \ No newline at end of file diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index a94de093..e393f4a0 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -50,12 +50,14 @@ class TrainConfig: self.train_text_encoder = kwargs.get('train_text_encoder', True) self.noise_offset = kwargs.get('noise_offset', 0.0) self.optimizer_params = kwargs.get('optimizer_params', {}) + self.skip_first_sample = kwargs.get('skip_first_sample', False) class ModelConfig: def __init__(self, **kwargs): 
self.name_or_path: str = kwargs.get('name_or_path', None) self.is_v2: bool = kwargs.get('is_v2', False) + self.is_xl: bool = kwargs.get('is_xl', False) self.is_v_pred: bool = kwargs.get('is_v_pred', False) if self.name_or_path is None: diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index 453de5ba..14d5cb07 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -1,8 +1,10 @@ import os import sys -from typing import List +from typing import List, Optional, Dict, Type, Union import torch +from transformers import CLIPTextModel + from .paths import SD_SCRIPTS_ROOT sys.path.append(SD_SCRIPTS_ROOT) @@ -14,26 +16,40 @@ class LoRASpecialNetwork(LoRANetwork): _multiplier: float = 1.0 is_active: bool = False + NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数 + + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] + UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + LORA_PREFIX_UNET = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + + # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER + LORA_PREFIX_TEXT_ENCODER1 = "lora_te1" + LORA_PREFIX_TEXT_ENCODER2 = "lora_te2" + def __init__( self, - text_encoder, + text_encoder: Union[List[CLIPTextModel], CLIPTextModel], unet, - multiplier=1.0, - lora_dim=4, - alpha=1, - dropout=None, - rank_dropout=None, - module_dropout=None, - conv_lora_dim=None, - conv_alpha=None, - block_dims=None, - block_alphas=None, - conv_block_dims=None, - conv_block_alphas=None, - modules_dim=None, - modules_alpha=None, - module_class=LoRAModule, - varbose=False, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + block_dims: Optional[List[int]] = None, + block_alphas: Optional[List[float]] = None, + conv_block_dims: Optional[List[int]] = None, + conv_block_alphas: Optional[List[float]] = None, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + module_class: Type[object] = LoRAModule, + varbose: Optional[bool] = False, + train_text_encoder: Optional[bool] = True, + train_unet: Optional[bool] = True, ) -> None: """ LoRA network: すごく引数が多いが、パターンは以下の通り @@ -75,8 +91,21 @@ class LoRASpecialNetwork(LoRANetwork): f"apply LoRA to Conv2d with kernel size (3,3). 
dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") # create module instances - def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]: - prefix = LoRANetwork.LORA_PREFIX_UNET if is_unet else LoRANetwork.LORA_PREFIX_TEXT_ENCODER + def create_modules( + is_unet: bool, + text_encoder_idx: Optional[int], # None, 1, 2 + root_module: torch.nn.Module, + target_replace_modules: List[torch.nn.Module], + ) -> List[LoRAModule]: + prefix = ( + self.LORA_PREFIX_UNET + if is_unet + else ( + self.LORA_PREFIX_TEXT_ENCODER + if text_encoder_idx is None + else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2) + ) + ) loras = [] skipped = [] for name, module in root_module.named_modules(): @@ -92,11 +121,14 @@ class LoRASpecialNetwork(LoRANetwork): dim = None alpha = None + if modules_dim is not None: + # モジュール指定あり if lora_name in modules_dim: dim = modules_dim[lora_name] alpha = modules_alpha[lora_name] elif is_unet and block_dims is not None: + # U-Netでblock_dims指定あり block_idx = get_block_index(lora_name) if is_linear or is_conv2d_1x1: dim = block_dims[block_idx] @@ -105,6 +137,7 @@ class LoRASpecialNetwork(LoRANetwork): dim = conv_block_dims[block_idx] alpha = conv_block_alphas[block_idx] else: + # 通常、すべて対象とする if is_linear or is_conv2d_1x1: dim = self.lora_dim alpha = self.alpha @@ -113,6 +146,7 @@ class LoRASpecialNetwork(LoRANetwork): alpha = self.conv_alpha if dim is None or dim == 0: + # skipした情報を出力 if is_linear or is_conv2d_1x1 or ( self.conv_lora_dim is not None or conv_block_dims is not None): skipped.append(lora_name) @@ -131,8 +165,25 @@ class LoRASpecialNetwork(LoRANetwork): loras.append(lora) return loras, skipped - self.text_encoder_loras, skipped_te = create_modules(False, text_encoder, + text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 + self.text_encoder_loras = [] + skipped_te = [] + if train_text_encoder: + for i, text_encoder in enumerate(text_encoders): + if len(text_encoders) > 1: + index = i + 1 + print(f"create LoRA for Text Encoder {index}:") + else: + index = None + print(f"create LoRA for Text Encoder:") + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights @@ -140,7 +191,11 @@ class LoRASpecialNetwork(LoRANetwork): if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None: target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 - self.unet_loras, skipped_un = create_modules(True, unet, target_modules) + if train_unet: + self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) + else: + self.unet_loras = [] + skipped_un = [] print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") skipped = skipped_te + skipped_un @@ -159,8 +214,7 @@ class LoRASpecialNetwork(LoRANetwork): # assertion names = set() for lora in self.text_encoder_loras + self.unet_loras: - # doesnt work on new diffusers. 
TODO make sure we are not missing something - # assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" names.add(lora.lora_name) def save_weights(self, file, dtype, metadata): diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py new file mode 100644 index 00000000..4541a2fe --- /dev/null +++ b/toolkit/stable_diffusion_model.py @@ -0,0 +1,63 @@ +from typing import Union +import sys +import os +from toolkit.paths import REPOS_ROOT +sys.path.append(REPOS_ROOT) +sys.path.append(os.path.join(REPOS_ROOT, 'leco')) +from leco import train_util +import torch + + +class PromptEmbeds: + text_embeds: torch.FloatTensor + pooled_embeds: Union[torch.FloatTensor, None] + + def __init__(self, args) -> None: + if isinstance(args, list) or isinstance(args, tuple): + # xl + self.text_embeds = args[0] + self.pooled_embeds = args[1] + else: + # sdv1.x, sdv2.x + self.text_embeds = args + self.pooled_embeds = None + + +class StableDiffusion: + def __init__( + self, + vae, + tokenizer, + text_encoder, + unet, + noise_scheduler, + is_xl=False + ): + # text encoder has a list of 2 for xl + self.vae = vae + self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.unet = unet + self.noise_scheduler = noise_scheduler + self.is_xl = is_xl + + def encode_prompt(self, prompt, num_images_per_prompt=1) -> PromptEmbeds: + prompt = prompt + # if it is not a list, make it one + if not isinstance(prompt, list): + prompt = [prompt] + if self.is_xl: + return PromptEmbeds( + train_util.encode_prompts_xl( + self.tokenizer, + self.text_encoder, + prompt, + num_images_per_prompt=num_images_per_prompt, + ) + ) + else: + return PromptEmbeds( + train_util.encode_prompts( + self.tokenizer, self.text_encoder, prompt + ) + ) diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py index 2082811c..8fa98a67 100644 --- a/toolkit/train_tools.py +++ b/toolkit/train_tools.py @@ -21,6 +21,8 @@ from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipel import torch import re +from toolkit.stable_diffusion_model import PromptEmbeds + SCHEDULER_LINEAR_START = 0.00085 SCHEDULER_LINEAR_END = 0.0120 SCHEDULER_TIMESTEPS = 1000 @@ -377,3 +379,19 @@ def apply_noise_offset(noise, noise_offset): return noise noise = noise + noise_offset * torch.randn((noise.shape[0], noise.shape[1], 1, 1), device=noise.device) return noise + + +def concat_prompt_embeddings( + unconditional: PromptEmbeds, + conditional: PromptEmbeds, + n_imgs: int, +): + text_embeds = torch.cat( + [unconditional.text_embeds, conditional.text_embeds] + ).repeat_interleave(n_imgs, dim=0) + pooled_embeds = None + if unconditional.pooled_embeds is not None and conditional.pooled_embeds is not None: + pooled_embeds = torch.cat( + [unconditional.pooled_embeds, conditional.pooled_embeds] + ).repeat_interleave(n_imgs, dim=0) + return PromptEmbeds([text_embeds, pooled_embeds])
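
For reference, below is a minimal self-contained sketch of how the new PromptEmbeds container (added in toolkit/stable_diffusion_model.py) and the concat_prompt_embeddings helper (added in toolkit/train_tools.py) pair an unconditional and a conditional prompt for classifier-free guidance, which is how TrainSliderProcess's get_noise_pred uses them. It is a sketch, not part of the patch: the tensor shapes are illustrative placeholders, and in the real code the embeddings come from StableDiffusion.encode_prompt rather than being constructed by hand.

    # Sketch of the PromptEmbeds / concat_prompt_embeddings interaction.
    # Shapes are illustrative; SDXL carries a pooled embedding alongside the token embeddings.
    import torch


    class PromptEmbeds:
        def __init__(self, args) -> None:
            if isinstance(args, (list, tuple)):
                # SDXL: (token embeddings, pooled embeddings)
                self.text_embeds, self.pooled_embeds = args[0], args[1]
            else:
                # SD 1.x / 2.x: token embeddings only
                self.text_embeds, self.pooled_embeds = args, None


    def concat_prompt_embeddings(unconditional, conditional, n_imgs):
        # Stack [unconditional, conditional] along the batch dim and repeat per image,
        # mirroring the helper added to toolkit/train_tools.py in this patch.
        text_embeds = torch.cat(
            [unconditional.text_embeds, conditional.text_embeds]
        ).repeat_interleave(n_imgs, dim=0)
        pooled = None
        if unconditional.pooled_embeds is not None and conditional.pooled_embeds is not None:
            pooled = torch.cat(
                [unconditional.pooled_embeds, conditional.pooled_embeds]
            ).repeat_interleave(n_imgs, dim=0)
        return PromptEmbeds([text_embeds, pooled])


    # SD 1.x style: token embeddings only (placeholder shape: batch, 77 tokens, 768 dims)
    uncond = PromptEmbeds(torch.zeros(1, 77, 768))
    cond = PromptEmbeds(torch.randn(1, 77, 768))
    pair = concat_prompt_embeddings(uncond, cond, n_imgs=2)
    print(pair.text_embeds.shape)  # torch.Size([4, 77, 768]); pooled_embeds stays None

    # SDXL style: token embeddings plus pooled embeddings (placeholder shapes)
    uncond_xl = PromptEmbeds((torch.zeros(1, 77, 2048), torch.zeros(1, 1280)))
    cond_xl = PromptEmbeds((torch.randn(1, 77, 2048), torch.randn(1, 1280)))
    pair_xl = concat_prompt_embeddings(uncond_xl, cond_xl, n_imgs=1)
    print(pair_xl.text_embeds.shape, pair_xl.pooled_embeds.shape)  # [2, 77, 2048] and [2, 1280]

Keeping pooled_embeds as None on the SD 1.x/2.x path lets the same call sites serve both model families: in BaseSDTrainProcess.predict_noise only the SDXL branch reads text_embeddings.pooled_embeds, while the non-XL branch passes text_embeddings.text_embeds alone.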