From 7165f2d25ac165ef1000a33f6fdff0fcf310e280 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sun, 23 Jun 2024 20:46:48 +0000
Subject: [PATCH] Work to improve pixart training

---
 jobs/process/BaseSDTrainProcess.py |  8 +++--
 requirements.txt                   |  3 +-
 toolkit/config_modules.py          |  2 ++
 toolkit/ip_adapter.py              |  2 +-
 toolkit/models/te_adapter.py       | 52 +++++++++++++++++++++++++++---
 toolkit/stable_diffusion_model.py  | 16 +++++----
 6 files changed, 65 insertions(+), 18 deletions(-)

diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 7b69b283..c93730e5 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -589,16 +589,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
         return latest_path

     def load_training_state_from_metadata(self, path):
+        meta = None
         # if path is folder, then it is diffusers
         if os.path.isdir(path):
             meta_path = os.path.join(path, 'aitk_meta.yaml')
             # load it
-            with open(meta_path, 'r') as f:
-                meta = yaml.load(f, Loader=yaml.FullLoader)
+            if os.path.exists(meta_path):
+                with open(meta_path, 'r') as f:
+                    meta = yaml.load(f, Loader=yaml.FullLoader)
         else:
             meta = load_metadata_from_safetensors(path)
         # if 'training_info' in Orderdict keys
-        if 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None:
+        if meta is not None and 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None:
             self.step_num = meta['training_info']['step']
             if 'epoch' in meta['training_info']:
                 self.epoch_num = meta['training_info']['epoch']
diff --git a/requirements.txt b/requirements.txt
index ce63f11b..347cc0e1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,6 +23,5 @@ prodigyopt
 controlnet_aux==0.0.7
 python-dotenv
 bitsandbytes
-xformers
 hf_transfer
-lpips
\ No newline at end of file
+lpips
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index c6a3245b..00dcfe77 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -372,6 +372,8 @@ class ModelConfig:

         # for text encoder quant. Only works with pixart currently
         self.text_encoder_bits = kwargs.get('text_encoder_bits', 8)  # 16, 8, 4
+        self.unet_path = kwargs.get("unet_path", None)
+        self.unet_sample_size = kwargs.get("unet_sample_size", None)


 class ReferenceDatasetConfig:
diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index 2b39e9ae..4e91e295 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -333,7 +333,7 @@ class IPAdapter(torch.nn.Module):
         if not self.config.train_image_encoder:
             # compile it
             print('Compiling image encoder')
-            torch.compile(self.image_encoder, fullgraph=True)
+            #torch.compile(self.image_encoder, fullgraph=True)

         self.input_size = self.image_encoder.config.image_size

diff --git a/toolkit/models/te_adapter.py b/toolkit/models/te_adapter.py
index a8676ec3..8daf6305 100644
--- a/toolkit/models/te_adapter.py
+++ b/toolkit/models/te_adapter.py
@@ -11,6 +11,7 @@ from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokeniz
 from toolkit import train_tools
 from toolkit.paths import REPOS_ROOT
 from toolkit.prompt_utils import PromptEmbeds
+from diffusers import Transformer2DModel

 sys.path.append(REPOS_ROOT)

@@ -176,6 +177,7 @@ class TEAdapter(torch.nn.Module):
         self.te_ref: weakref.ref = weakref.ref(te)
         self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
         self.adapter_modules = []
+        is_pixart = sd.is_pixart

         if self.adapter_ref().config.text_encoder_arch == "t5":
             self.token_size = self.te_ref().config.d_model
@@ -195,9 +197,28 @@ class TEAdapter(torch.nn.Module):
             }

         module_idx = 0
-        attn_processors_list = list(sd.unet.attn_processors.keys())
-        for name in sd.unet.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim']
+        # init adapter modules
+        attn_procs = {}
+        unet_sd = sd.unet.state_dict()
+        attn_processor_keys = []
+        if is_pixart:
+            transformer: Transformer2DModel = sd.unet
+            for i, module in transformer.transformer_blocks.named_children():
+                attn_processor_keys.append(f"transformer_blocks.{i}.attn1")
+
+                # cross attention
+                attn_processor_keys.append(f"transformer_blocks.{i}.attn2")
+
+        else:
+            attn_processor_keys = list(sd.unet.attn_processors.keys())
+
+        attn_processor_names = []
+
+        blocks = []
+        transformer_blocks = []
+        for name in attn_processor_keys:
+            cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \
+                sd.unet.config['cross_attention_dim']
             if name.startswith("mid_block"):
                 hidden_size = sd.unet.config['block_out_channels'][-1]
             elif name.startswith("up_blocks"):
@@ -206,6 +227,8 @@ class TEAdapter(torch.nn.Module):
             elif name.startswith("down_blocks"):
                 block_id = int(name[len("down_blocks.")])
                 hidden_size = sd.unet.config['block_out_channels'][block_id]
+            elif name.startswith("transformer"):
+                hidden_size = sd.unet.config['cross_attention_dim']
             else:
                 # they didnt have this, but would lead to undefined below
                 raise ValueError(f"unknown attn processor name: {name}")
@@ -244,6 +267,13 @@ class TEAdapter(torch.nn.Module):
                 "to_k_adapter.weight": to_k_adapter,
                 "to_v_adapter.weight": to_v_adapter,
             }
+
+            if self.sd_ref().is_pixart:
+                # pixart is much more sensitive
+                weights = {
+                    "to_k_adapter.weight": weights["to_k_adapter.weight"] * 0.01,
+                    "to_v_adapter.weight": weights["to_v_adapter.weight"] * 0.01,
+                }

             attn_procs[name] = TEAdapterAttnProcessor(
                 hidden_size=hidden_size,
@@ -256,8 +286,20 @@ class TEAdapter(torch.nn.Module):
             )
             attn_procs[name].load_state_dict(weights)
             self.adapter_modules.append(attn_procs[name])
-        sd.unet.set_attn_processor(attn_procs)
-        self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
+        if self.sd_ref().is_pixart:
+            # we have to set them ourselves
+            transformer: Transformer2DModel = sd.unet
+            for i, module in transformer.transformer_blocks.named_children():
+                module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"]
+                module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"]
+            self.adapter_modules = torch.nn.ModuleList(
+                [
+                    transformer.transformer_blocks[i].attn2.processor for i in
+                    range(len(transformer.transformer_blocks))
+                ])
+        else:
+            sd.unet.set_attn_processor(attn_procs)
+            self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())

     # make a getter to see if is active
     @property
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 21d83cfb..5e4ad77c 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -38,9 +38,9 @@ import torch
 from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
     StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
 from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
-    StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
+    StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
     StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
-    StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline
+    StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline
 import diffusers
 from diffusers import \
     AutoencoderKL, \
@@ -185,7 +185,6 @@ class StableDiffusion:
         }
         if self.model_config.vae_path is not None:
             load_args['vae'] = load_vae(self.model_config.vae_path, dtype)
-
         if self.model_config.is_xl or self.model_config.is_ssd or self.model_config.is_vega:
             if self.custom_pipeline is not None:
                 pipln = self.custom_pipeline
@@ -331,11 +330,11 @@ class StableDiffusion:
             if self.model_config.is_pixart_sigma:
                 # load the transformer only from the save
                 transformer = Transformer2DModel.from_pretrained(
-                    model_path,
+                    model_path if self.model_config.unet_path is None else self.model_config.unet_path,
                     torch_dtype=self.torch_dtype,
                     subfolder='transformer'
                 )
-                pipe: PixArtAlphaPipeline = PixArtSigmaPipeline.from_pretrained(
+                pipe: PixArtSigmaPipeline = PixArtSigmaPipeline.from_pretrained(
                     main_model_path,
                     transformer=transformer,
                     text_encoder=text_encoder,
@@ -357,6 +356,9 @@ class StableDiffusion:
                     device=self.device_torch,
                     **load_args
                 ).to(self.device_torch)
+
+            if self.model_config.unet_sample_size is not None:
+                pipe.transformer.config.sample_size = self.model_config.unet_sample_size

             pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
             flush()
@@ -592,7 +594,7 @@ class StableDiffusion:
                 **extra_args
             )
         elif self.is_pixart:
-            pipeline = PixArtAlphaPipeline(
+            pipeline = PixArtSigmaPipeline(
                 vae=self.vae,
                 transformer=self.unet,
                 text_encoder=self.text_encoder,
@@ -1243,7 +1245,7 @@ class StableDiffusion:
            elif self.pipeline.transformer.config.sample_size == 32:
                aspect_ratio_bin = ASPECT_RATIO_256_BIN
            else:
-               raise ValueError("Invalid sample size")
+               raise ValueError(f"Invalid sample size: {self.pipeline.transformer.config.sample_size}")
            orig_height, orig_width = height, width
            height, width = self.pipeline.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
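
Usage sketch (not part of the patch): the two new ModelConfig options added above, unet_path and unet_sample_size, are read with kwargs.get and then consumed in toolkit/stable_diffusion_model.py to load the PixArt transformer from a separate checkpoint and to override pipe.transformer.config.sample_size. The snippet below is a minimal, assumed example; the name_or_path and is_pixart_sigma keys and the keyword-argument constructor are assumptions about the toolkit's existing config plumbing, and all paths/values are placeholders.

    # hypothetical sketch, not taken from the patch
    from toolkit.config_modules import ModelConfig

    model_config = ModelConfig(
        name_or_path="PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",  # assumed base model id
        is_pixart_sigma=True,                                   # assumed existing flag
        unet_path="/path/to/saved/transformer",                 # new: alternate transformer weights to load
        unet_sample_size=64,                                    # new: overrides transformer.config.sample_size
    )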