Work to improve pixart training

Jaret Burkett
2024-06-23 20:46:48 +00:00
parent 5d47244c57
commit 7165f2d25a
6 changed files with 65 additions and 18 deletions


@@ -11,6 +11,7 @@ from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokeniz
from toolkit import train_tools
from toolkit.paths import REPOS_ROOT
from toolkit.prompt_utils import PromptEmbeds
from diffusers import Transformer2DModel
sys.path.append(REPOS_ROOT)
@@ -176,6 +177,7 @@ class TEAdapter(torch.nn.Module):
self.te_ref: weakref.ref = weakref.ref(te)
self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
self.adapter_modules = []
is_pixart = sd.is_pixart
if self.adapter_ref().config.text_encoder_arch == "t5":
self.token_size = self.te_ref().config.d_model
@@ -195,9 +197,28 @@ class TEAdapter(torch.nn.Module):
}
module_idx = 0
attn_processors_list = list(sd.unet.attn_processors.keys())
for name in sd.unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim']
# init adapter modules
attn_procs = {}
unet_sd = sd.unet.state_dict()
attn_processor_keys = []
if is_pixart:
transformer: Transformer2DModel = sd.unet
for i, module in transformer.transformer_blocks.named_children():
attn_processor_keys.append(f"transformer_blocks.{i}.attn1")
# cross attention
attn_processor_keys.append(f"transformer_blocks.{i}.attn2")
else:
attn_processor_keys = list(sd.unet.attn_processors.keys())
attn_processor_names = []
blocks = []
transformer_blocks = []
for name in attn_processor_keys:
cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \
sd.unet.config['cross_attention_dim']
if name.startswith("mid_block"):
hidden_size = sd.unet.config['block_out_channels'][-1]
elif name.startswith("up_blocks"):
@@ -206,6 +227,8 @@ class TEAdapter(torch.nn.Module):
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = sd.unet.config['block_out_channels'][block_id]
elif name.startswith("transformer"):
hidden_size = sd.unet.config['cross_attention_dim']
else:
# upstream code didn't have this branch, but without it hidden_size would be undefined below
raise ValueError(f"unknown attn processor name: {name}")
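The branching above is effectively a lookup from processor name to hidden size. A compact restatement, assuming a diffusers-style config dict; the up_blocks body is elided in this hunk, so the reversed-channels lookup shown is the conventional pattern from similar adapter code and should be treated as an assumption:

def resolve_hidden_size(name: str, config: dict) -> int:
    # UNet block names index into block_out_channels; PixArt
    # transformer blocks use the model-wide cross_attention_dim.
    if name.startswith("mid_block"):
        return config['block_out_channels'][-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        # assumed: mirrors the usual reversed-channels lookup
        return list(reversed(config['block_out_channels']))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        return config['block_out_channels'][block_id]
    elif name.startswith("transformer"):
        return config['cross_attention_dim']
    raise ValueError(f"unknown attn processor name: {name}")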
@@ -244,6 +267,13 @@ class TEAdapter(torch.nn.Module):
"to_k_adapter.weight": to_k_adapter,
"to_v_adapter.weight": to_v_adapter,
}
if self.sd_ref().is_pixart:
# pixart is much more sensitive, so scale the initial adapter weights down
weights = {
"to_k_adapter.weight": weights["to_k_adapter.weight"] * 0.01,
"to_v_adapter.weight": weights["to_v_adapter.weight"] * 0.01,
}
attn_procs[name] = TEAdapterAttnProcessor(
hidden_size=hidden_size,
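The 0.01 factor above is a plain element-wise down-scale of the freshly initialized projection weights, so the adapter starts near zero and grows during training. A minimal sketch with illustrative shapes (1152 and 4096 are assumptions, not the real dimensions):

import torch

# Illustrative: to_k/to_v adapters project text-encoder hidden
# states to the attention hidden size.
to_k_adapter = torch.randn(1152, 4096)
to_v_adapter = torch.randn(1152, 4096)

weights = {
    "to_k_adapter.weight": to_k_adapter,
    "to_v_adapter.weight": to_v_adapter,
}

# Start the adapter near zero and let training grow its influence.
weights = {k: v * 0.01 for k, v in weights.items()}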
@@ -256,8 +286,20 @@ class TEAdapter(torch.nn.Module):
)
attn_procs[name].load_state_dict(weights)
self.adapter_modules.append(attn_procs[name])
sd.unet.set_attn_processor(attn_procs)
self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
if self.sd_ref().is_pixart:
# we have to set the processors on each transformer block ourselves
transformer: Transformer2DModel = sd.unet
for i, module in transformer.transformer_blocks.named_children():
module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"]
module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"]
self.adapter_modules = torch.nn.ModuleList(
[
transformer.transformer_blocks[i].attn2.processor for i in
range(len(transformer.transformer_blocks))
])
else:
sd.unet.set_attn_processor(attn_procs)
self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
# make a getter to see if is active
@property
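For the PixArt path above, processors are wired by direct attribute assignment rather than via set_attn_processor, and only the cross-attention (attn2) processors end up in adapter_modules. A self-contained mock of that wiring, with stand-in classes for the diffusers Attention modules (all names below are illustrative):

import torch

class MockProcessor:
    pass

class MockAttn(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.processor = None  # diffusers Attention stores its processor here

class MockBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn1 = MockAttn()  # self attention
        self.attn2 = MockAttn()  # cross attention

blocks = torch.nn.ModuleList([MockBlock() for _ in range(2)])
attn_procs = {f"transformer_blocks.{i}.attn{j}": MockProcessor()
              for i in range(2) for j in (1, 2)}

# Mirror of the commit's manual wiring: assign each processor directly.
for i, module in blocks.named_children():
    module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"]
    module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"]

# Only the cross-attention (attn2) processors are collected, matching
# how adapter_modules is built for PixArt in this commit.
adapter_modules = [blocks[i].attn2.processor for i in range(len(blocks))]
assert blocks[0].attn1.processor is attn_procs["transformer_blocks.0.attn1"]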