From 2478554c95e0390dda38db7b08ae429687fe4ffa Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 17 Feb 2024 10:06:57 -0700 Subject: [PATCH] Bug fixes. Added IP adapter training for Pixart --- extensions_built_in/sd_trainer/SDTrainer.py | 14 +- toolkit/ip_adapter.py | 129 +++++++++++---- toolkit/models/vd_adapter.py | 171 ++++++++++++++++++-- toolkit/stable_diffusion_model.py | 13 +- 4 files changed, 278 insertions(+), 49 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index ab257209..ba74286a 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -346,8 +346,9 @@ class SDTrainer(BaseSDTrainProcess): print("Prior loss is nan") prior_loss = None else: - # prior_loss = prior_loss.mean([1, 2, 3]) - loss = loss + prior_loss + prior_loss = prior_loss.mean([1, 2, 3]) + # loss = loss + prior_loss + # loss = loss + prior_loss # loss = loss + prior_loss loss = loss.mean([1, 2, 3]) if prior_loss is not None: @@ -731,6 +732,15 @@ class SDTrainer(BaseSDTrainProcess): # self.network.multiplier = 0.0 self.sd.unet.eval() + if self.adapter is not None and isinstance(self.adapter, IPAdapter): + # we need to remove the image embeds from the prompt + embeds_to_use: PromptEmbeds = embeds_to_use.clone().detach() + end_pos = embeds_to_use.text_embeds.shape[1] - self.adapter_config.num_tokens + embeds_to_use.text_embeds = embeds_to_use.text_embeds[:, :end_pos, :] + if unconditional_embeds is not None: + unconditional_embeds = unconditional_embeds.clone().detach() + unconditional_embeds.text_embeds = unconditional_embeds.text_embeds[:, :end_pos] + if unconditional_embeds is not None: unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach() diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py index 1751f3e2..c3e2d33b 100644 --- a/toolkit/ip_adapter.py +++ b/toolkit/ip_adapter.py @@ -4,6 +4,7 @@ import torch import sys from PIL import Image +from diffusers import Transformer2DModel from torch.nn import Parameter from torch.nn.modules.module import T from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection @@ -79,6 +80,10 @@ class CustomIPAttentionProcessor(IPAttnProcessor2_0): hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) + if is_active: + # since we are removing tokens, we need to adjust the sequence length + sequence_length = sequence_length - self.num_tokens + if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be @@ -90,6 +95,9 @@ class CustomIPAttentionProcessor(IPAttnProcessor2_0): query = attn.to_q(hidden_states) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + # will be none if disabled if not is_active: ip_hidden_states = None @@ -120,9 +128,13 @@ class CustomIPAttentionProcessor(IPAttnProcessor2_0): # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) + try: + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + except Exception as e: + print(e) + raise e hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 
hidden_states = hidden_states.to(query.dtype) @@ -235,7 +247,7 @@ class IPAdapter(torch.nn.Module): print(f"could not load image processor from {adapter_config.image_encoder_path}") self.clip_image_processor = ConvNextImageProcessor( size=512, - image_mean=[0.485,0.456,0.406], + image_mean=[0.485, 0.456, 0.406], image_std=[0.229, 0.224, 0.225], ) self.image_encoder = ConvNextV2ForImageClassification.from_pretrained( @@ -299,6 +311,7 @@ class IPAdapter(torch.nn.Module): raise ValueError(f"unknown image processor size: {self.clip_image_processor.size}") self.current_scale = 1.0 self.is_active = True + is_pixart = sd.is_pixart if adapter_config.type == 'ip': # ip-adapter image_proj_model = ImageProjModel( @@ -310,14 +323,22 @@ class IPAdapter(torch.nn.Module): heads = 12 if not sd.is_xl else 20 dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280 embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else \ - self.image_encoder.config.hidden_sizes[-1] + self.image_encoder.config.hidden_sizes[-1] image_encoder_state_dict = self.image_encoder.state_dict() # max_seq_len = CLIP tokens + CLS token max_seq_len = 257 if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict: # clip - max_seq_len = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0]) + max_seq_len = int( + image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0]) + + output_dim = sd.unet.config['cross_attention_dim'] + + if is_pixart: + heads = 20 + dim = 4096 + output_dim = 4096 # ip-adapter-plus image_proj_model = Resampler( @@ -328,7 +349,7 @@ class IPAdapter(torch.nn.Module): num_queries=self.config.num_tokens if self.config.num_tokens > 0 else max_seq_len, embedding_dim=embedding_dim, max_seq_len=max_seq_len, - output_dim=sd.unet.config['cross_attention_dim'], + output_dim=output_dim, ff_mult=4 ) elif adapter_config.type == 'ipz': @@ -373,8 +394,21 @@ class IPAdapter(torch.nn.Module): # init adapter modules attn_procs = {} unet_sd = sd.unet.state_dict() - for name in sd.unet.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim'] + attn_processor_keys = [] + if is_pixart: + transformer: Transformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + attn_processor_keys.append(f"transformer_blocks.{i}.attn1") + + # cross attention + attn_processor_keys.append(f"transformer_blocks.{i}.attn2") + + else: + attn_processor_keys = list(sd.unet.attn_processors.keys()) + + for name in attn_processor_keys: + cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else \ + sd.unet.config['cross_attention_dim'] if name.startswith("mid_block"): hidden_size = sd.unet.config['block_out_channels'][-1] elif name.startswith("up_blocks"): @@ -383,6 +417,8 @@ class IPAdapter(torch.nn.Module): elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) hidden_size = sd.unet.config['block_out_channels'][block_id] + elif name.startswith("transformer"): + hidden_size = sd.unet.config['cross_attention_dim'] else: # they didnt have this, but would lead to undefined below raise ValueError(f"unknown attn processor name: {name}") @@ -402,14 +438,35 @@ class IPAdapter(torch.nn.Module): num_tokens=self.config.num_tokens, adapter=self ) + if self.sd_ref().is_pixart: + # pixart is much more sensitive + weights = { + "to_k_ip.weight": 
weights["to_k_ip.weight"] * 0.01, + "to_v_ip.weight": weights["to_v_ip.weight"] * 0.01, + } + attn_procs[name].load_state_dict(weights) - sd.unet.set_attn_processor(attn_procs) - adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values()) + if self.sd_ref().is_pixart: + # we have to set them ourselves + transformer: Transformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"] + module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"] + self.adapter_modules = torch.nn.ModuleList( + [ + transformer.transformer_blocks[i].attn1.processor for i in + range(len(transformer.transformer_blocks)) + ] + [ + transformer.transformer_blocks[i].attn2.processor for i in + range(len(transformer.transformer_blocks)) + ]) + else: + sd.unet.set_attn_processor(attn_procs) + self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values()) sd.adapter = self self.unet_ref: weakref.ref = weakref.ref(sd.unet) self.image_proj_model = image_proj_model - self.adapter_modules = adapter_modules # load the weights if we have some if self.config.name_or_path: loaded_state_dict = load_ip_adapter_model( @@ -473,9 +530,10 @@ class IPAdapter(torch.nn.Module): def set_scale(self, scale): self.current_scale = scale - for attn_processor in self.sd_ref().unet.attn_processors.values(): - if isinstance(attn_processor, CustomIPAttentionProcessor): - attn_processor.scale = scale + if not self.sd_ref().is_pixart: + for attn_processor in self.sd_ref().unet.attn_processors.values(): + if isinstance(attn_processor, CustomIPAttentionProcessor): + attn_processor.scale = scale # @torch.no_grad() # def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]], @@ -554,7 +612,7 @@ class IPAdapter(torch.nn.Module): if self.clip_noise_zero: tensors_0_1 = torch.rand_like(tensors_0_1).detach() noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device, - dtype=get_torch_dtype(self.sd_ref().dtype)) + dtype=get_torch_dtype(self.sd_ref().dtype)) tensors_0_1 = tensors_0_1 * noise_scale else: tensors_0_1 = torch.zeros_like(tensors_0_1).detach() @@ -675,7 +733,6 @@ class IPAdapter(torch.nn.Module): embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1) return embeddings - def train(self: T, mode: bool = True) -> T: if self.config.train_image_encoder: self.image_encoder.train(mode) @@ -721,18 +778,22 @@ class IPAdapter(torch.nn.Module): raise ValueError(f"unknown shape: {current_shape}") except RuntimeError as e: print(e) - print(f"could not merge in {key}: {list(current_shape)} <<< {list(new_shape)}. Trying other way") + print( + f"could not merge in {key}: {list(current_shape)} <<< {list(new_shape)}. 
Trying other way") if len(current_shape) == 1: current_img_proj_state_dict[key][:current_shape[0]] = value[:current_shape[0]] elif len(current_shape) == 2: - current_img_proj_state_dict[key][:current_shape[0], :current_shape[1]] = value[:current_shape[0], :current_shape[1]] + current_img_proj_state_dict[key][:current_shape[0], :current_shape[1]] = value[ + :current_shape[0], + :current_shape[1]] elif len(current_shape) == 3: - current_img_proj_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2]] = value[:current_shape[0], :current_shape[1], :current_shape[2]] + current_img_proj_state_dict[key][:current_shape[0], :current_shape[1], + :current_shape[2]] = value[:current_shape[0], :current_shape[1], :current_shape[2]] elif len(current_shape) == 4: current_img_proj_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2], :current_shape[3]] = value[:current_shape[0], :current_shape[1], :current_shape[2], - :current_shape[3]] + :current_shape[3]] else: raise ValueError(f"unknown shape: {current_shape}") print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}") @@ -763,16 +824,24 @@ class IPAdapter(torch.nn.Module): print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}") except RuntimeError as e: print(e) - print(f"could not merge in {key}: {list(current_shape)} <<< {list(new_shape)}. Trying other way") + print( + f"could not merge in {key}: {list(current_shape)} <<< {list(new_shape)}. Trying other way") - if(len(current_shape) == 1): + if (len(current_shape) == 1): current_ip_adapter_state_dict[key][:current_shape[0]] = value[:current_shape[0]] - elif(len(current_shape) == 2): - current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1]] = value[:current_shape[0], :current_shape[1]] - elif(len(current_shape) == 3): - current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2]] = value[:current_shape[0], :current_shape[1], :current_shape[2]] - elif(len(current_shape) == 4): - current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2], :current_shape[3]] = value[:current_shape[0], :current_shape[1], :current_shape[2], :current_shape[3]] + elif (len(current_shape) == 2): + current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1]] = value[ + :current_shape[ + 0], + :current_shape[ + 1]] + elif (len(current_shape) == 3): + current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1], + :current_shape[2]] = value[:current_shape[0], :current_shape[1], :current_shape[2]] + elif (len(current_shape) == 4): + current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2], + :current_shape[3]] = value[:current_shape[0], :current_shape[1], :current_shape[2], + :current_shape[3]] else: raise ValueError(f"unknown shape: {current_shape}") print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}") @@ -781,7 +850,6 @@ class IPAdapter(torch.nn.Module): current_ip_adapter_state_dict[key] = value self.adapter_modules.load_state_dict(current_ip_adapter_state_dict) - def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): strict = False if 'ip_adapter' in state_dict: @@ -801,7 +869,6 @@ class IPAdapter(torch.nn.Module): # we are loading pure clip weights. 
self.image_encoder.load_state_dict(state_dict, strict=strict) - def enable_gradient_checkpointing(self): if hasattr(self.image_encoder, "enable_gradient_checkpointing"): self.image_encoder.enable_gradient_checkpointing() diff --git a/toolkit/models/vd_adapter.py b/toolkit/models/vd_adapter.py index f8a6dd03..a83e154b 100644 --- a/toolkit/models/vd_adapter.py +++ b/toolkit/models/vd_adapter.py @@ -6,16 +6,103 @@ import torch.nn.functional as F import weakref from typing import Union, TYPE_CHECKING +from diffusers import Transformer2DModel from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection from toolkit.paths import REPOS_ROOT sys.path.append(REPOS_ROOT) -from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0 if TYPE_CHECKING: from toolkit.stable_diffusion_model import StableDiffusion from toolkit.custom_adapter import CustomAdapter +class AttnProcessor2_0(torch.nn.Module): + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + ): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = 
hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states class VisionDirectAdapterAttnProcessor(nn.Module): r""" @@ -31,7 +118,7 @@ class VisionDirectAdapterAttnProcessor(nn.Module): """ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, adapter=None, - adapter_hidden_size=None): + adapter_hidden_size=None, has_bias=False, **kwargs): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): @@ -44,12 +131,13 @@ class VisionDirectAdapterAttnProcessor(nn.Module): self.cross_attention_dim = cross_attention_dim self.scale = scale - self.to_k_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=False) - self.to_v_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=False) + self.to_k_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias) + self.to_v_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias) @property def is_active(self): return self.adapter_ref().is_active + # return False @property def unconditional_embeds(self): @@ -175,6 +263,7 @@ class VisionDirectAdapter(torch.nn.Module): vision_model: Union[CLIPVisionModelWithProjection], ): super(VisionDirectAdapter, self).__init__() + is_pixart = sd.is_pixart self.adapter_ref: weakref.ref = weakref.ref(adapter) self.sd_ref: weakref.ref = weakref.ref(sd) self.vision_model_ref: weakref.ref = weakref.ref(vision_model) @@ -184,8 +273,22 @@ class VisionDirectAdapter(torch.nn.Module): # init adapter modules attn_procs = {} unet_sd = sd.unet.state_dict() - for name in sd.unet.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim'] + + attn_processor_keys = [] + if is_pixart: + transformer: Transformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + + attn_processor_keys.append(f"transformer_blocks.{i}.attn1") + + # cross attention + attn_processor_keys.append(f"transformer_blocks.{i}.attn2") + + else: + attn_processor_keys = list(sd.unet.attn_processors.keys()) + + for name in attn_processor_keys: + cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else sd.unet.config['cross_attention_dim'] if name.startswith("mid_block"): hidden_size = sd.unet.config['block_out_channels'][-1] elif name.startswith("up_blocks"): @@ -194,6 +297,8 @@ class VisionDirectAdapter(torch.nn.Module): elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) hidden_size = sd.unet.config['block_out_channels'][block_id] + elif name.startswith("transformer"): + hidden_size = sd.unet.config['cross_attention_dim'] else: # they didnt have this, but would lead to undefined below raise ValueError(f"unknown attn processor name: {name}") @@ -203,6 +308,12 @@ class VisionDirectAdapter(torch.nn.Module): layer_name = name.split(".processor")[0] to_k_adapter = unet_sd[layer_name + ".to_k.weight"] to_v_adapter = unet_sd[layer_name + ".to_v.weight"] + # if is_pixart: + # to_k_bias = unet_sd[layer_name + ".to_k.bias"] + # to_v_bias = unet_sd[layer_name + ".to_v.bias"] + # else: + # to_k_bias = None + # to_v_bias = None # add zero padding to the adapter if to_k_adapter.shape[1] < self.token_size: @@ -220,29 +331,65 @@ class VisionDirectAdapter(torch.nn.Module): ], dim=1 ) + # if is_pixart: + # to_k_bias = torch.cat([ + # to_k_bias, + # torch.zeros(self.token_size - 
to_k_adapter.shape[1]).to( + # to_k_adapter.device, dtype=to_k_adapter.dtype) + # ], + # dim=0 + # ) + # to_v_bias = torch.cat([ + # to_v_bias, + # torch.zeros(self.token_size - to_v_adapter.shape[1]).to( + # to_k_adapter.device, dtype=to_k_adapter.dtype) + # ], + # dim=0 + # ) elif to_k_adapter.shape[1] > self.token_size: to_k_adapter = to_k_adapter[:, :self.token_size] to_v_adapter = to_v_adapter[:, :self.token_size] + # if is_pixart: + # to_k_bias = to_k_bias[:self.token_size] + # to_v_bias = to_v_bias[:self.token_size] else: to_k_adapter = to_k_adapter to_v_adapter = to_v_adapter + # if is_pixart: + # to_k_bias = to_k_bias + # to_v_bias = to_v_bias - # todo resize to the TE hidden size weights = { - "to_k_adapter.weight": to_k_adapter, - "to_v_adapter.weight": to_v_adapter, + "to_k_adapter.weight": to_k_adapter * 0.01, + "to_v_adapter.weight": to_v_adapter * 0.01, } + # if is_pixart: + # weights["to_k_adapter.bias"] = to_k_bias + # weights["to_v_adapter.bias"] = to_v_bias attn_procs[name] = VisionDirectAdapterAttnProcessor( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, adapter=self, - adapter_hidden_size=self.token_size + adapter_hidden_size=self.token_size, + has_bias=False, ) attn_procs[name].load_state_dict(weights) - sd.unet.set_attn_processor(attn_procs) - self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values()) + if self.sd_ref().is_pixart: + # we have to set them ourselves + transformer: Transformer2DModel = sd.unet + for i, module in transformer.transformer_blocks.named_children(): + module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"] + module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"] + self.adapter_modules = torch.nn.ModuleList([ + transformer.transformer_blocks[i].attn1.processor for i in range(len(transformer.transformer_blocks)) + ] + [ + transformer.transformer_blocks[i].attn2.processor for i in range(len(transformer.transformer_blocks)) + ]) + else: + sd.unet.set_attn_processor(attn_procs) + self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values()) # make a getter to see if is active @property diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 2fc4cc5e..61139f8d 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -1567,10 +1567,15 @@ class StableDiffusion: named_params = self.named_parameters(vae=False, unet=unet, text_encoder=False, state_dict_keys=True) unet_lr = unet_lr if unet_lr is not None else default_lr params = [] - for key, diffusers_key in ldm_diffusers_keymap.items(): - if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS: - if named_params[diffusers_key].requires_grad: - params.append(named_params[diffusers_key]) + if self.is_pixart: + for param in named_params.values(): + if param.requires_grad: + params.append(param) + else: + for key, diffusers_key in ldm_diffusers_keymap.items(): + if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS: + if named_params[diffusers_key].requires_grad: + params.append(named_params[diffusers_key]) param_data = {"params": params, "lr": unet_lr} trainable_parameters.append(param_data) print(f"Found {len(params)} trainable parameter in unet")
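
For reference, the SDTrainer change above ("we need to remove the image embeds from the prompt") reduces to slicing the adapter's image tokens off the end of the text embeddings before they are reused. A minimal, self-contained sketch; the token count, batch size, and embedding width below are illustrative values, not taken from this patch:

    import torch

    num_tokens = 4                        # stands in for adapter_config.num_tokens
    batch, seq_len, dim = 2, 77, 4096     # illustrative shapes only

    # text embeddings with the IP adapter's image tokens already concatenated on the end
    text_embeds = torch.randn(batch, seq_len + num_tokens, dim)

    # the trainer trims the image tokens back off, exactly as the hunk does
    end_pos = text_embeds.shape[1] - num_tokens
    text_embeds_no_image = text_embeds[:, :end_pos, :]

    assert text_embeds_no_image.shape == (batch, seq_len, dim)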
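
The PixArt path in toolkit/ip_adapter.py and toolkit/models/vd_adapter.py builds the processor dict by walking transformer_blocks directly and then assigns each block's attn1/attn2 processor by hand ("we have to set them ourselves") instead of going through sd.unet.set_attn_processor. A rough sketch of that wiring, with diffusers' stock AttnProcessor2_0 standing in for CustomIPAttentionProcessor so the snippet stays self-contained, and the checkpoint id assumed rather than taken from this patch:

    import torch
    from diffusers import PixArtAlphaPipeline
    from diffusers.models.attention_processor import AttnProcessor2_0

    pipe = PixArtAlphaPipeline.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-1024-MS",  # assumed checkpoint, not pinned by this patch
        torch_dtype=torch.float16,
    )
    transformer = pipe.transformer

    # key naming mirrors the is_pixart branch: transformer_blocks.<i>.attn1 / .attn2
    attn_procs = {}
    for i, block in transformer.transformer_blocks.named_children():
        attn_procs[f"transformer_blocks.{i}.attn1"] = AttnProcessor2_0()  # self-attention
        attn_procs[f"transformer_blocks.{i}.attn2"] = AttnProcessor2_0()  # cross-attention

    # direct assignment per block, since the transformer is not driven
    # through a set_attn_processor() call in this code path
    for i, block in transformer.transformer_blocks.named_children():
        block.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"]
        block.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"]

In the patch itself the processors are CustomIPAttentionProcessor / VisionDirectAdapterAttnProcessor instances whose to_k_ip / to_v_ip (or to_k_adapter / to_v_adapter) weights are scaled by 0.01 at initialization, which the patch attributes to PixArt being much more sensitive to the injected image tokens.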