Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-20 22:39:03 +00:00
Work to improve pixart training
@@ -11,6 +11,7 @@ from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokeniz
from toolkit import train_tools
from toolkit.paths import REPOS_ROOT
from toolkit.prompt_utils import PromptEmbeds
from diffusers import Transformer2DModel

sys.path.append(REPOS_ROOT)

@@ -176,6 +177,7 @@ class TEAdapter(torch.nn.Module):
self.te_ref: weakref.ref = weakref.ref(te)
self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
self.adapter_modules = []
is_pixart = sd.is_pixart

if self.adapter_ref().config.text_encoder_arch == "t5":
    self.token_size = self.te_ref().config.d_model
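Note on the weakref pattern used here: storing weakref.ref(te) rather than te itself keeps the adapter from holding a strong reference back to the text encoder, and the reference is dereferenced by calling it, as self.te_ref() does below. A minimal standalone sketch of that pattern, with a stub class standing in for the T5 encoder:

    import weakref

    class EncoderStub:
        # stand-in for a text encoder whose config carries d_model
        d_model = 4096

    te = EncoderStub()
    te_ref = weakref.ref(te)     # store a weak reference, not the object
    print(te_ref().d_model)      # calling the ref returns the target -> 4096
    del te
    print(te_ref())              # None once the target has been collected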
@@ -195,9 +197,28 @@ class TEAdapter(torch.nn.Module):

}
module_idx = 0
attn_processors_list = list(sd.unet.attn_processors.keys())
for name in sd.unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim']
# init adapter modules
attn_procs = {}
unet_sd = sd.unet.state_dict()
attn_processor_keys = []
if is_pixart:
    transformer: Transformer2DModel = sd.unet
    for i, module in transformer.transformer_blocks.named_children():
        attn_processor_keys.append(f"transformer_blocks.{i}.attn1")

        # cross attention
        attn_processor_keys.append(f"transformer_blocks.{i}.attn2")

else:
    attn_processor_keys = list(sd.unet.attn_processors.keys())

attn_processor_names = []

blocks = []
transformer_blocks = []
for name in attn_processor_keys:
    cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \
        sd.unet.config['cross_attention_dim']
    if name.startswith("mid_block"):
        hidden_size = sd.unet.config['block_out_channels'][-1]
    elif name.startswith("up_blocks"):
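For the pixart path, the processor keys are built by walking transformer_blocks rather than reading sd.unet.attn_processors. A small sketch of the naming scheme the loop above produces; the block count of 28 is just an example value, not taken from any model config:

    # Each transformer block contributes a self-attention (attn1) and a
    # cross-attention (attn2) key, mirroring the loop above.
    num_blocks = 28  # example value only
    attn_processor_keys = []
    for i in range(num_blocks):
        attn_processor_keys.append(f"transformer_blocks.{i}.attn1")
        attn_processor_keys.append(f"transformer_blocks.{i}.attn2")

    print(attn_processor_keys[:2])
    # ['transformer_blocks.0.attn1', 'transformer_blocks.0.attn2']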
@@ -206,6 +227,8 @@ class TEAdapter(torch.nn.Module):
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = sd.unet.config['block_out_channels'][block_id]
    elif name.startswith("transformer"):
        hidden_size = sd.unet.config['cross_attention_dim']
    else:
        # they didn't have this, but without it hidden_size would be undefined below
        raise ValueError(f"unknown attn processor name: {name}")
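The hidden size for each processor is derived from the module path alone. A self-contained sketch of that dispatch, with example config values standing in for sd.unet.config; the up_blocks body is elided in the hunk above, so the reversed-channels lookup shown here is an assumption based on the usual diffusers pattern:

    config = {
        'block_out_channels': [320, 640, 1280, 1280],  # example UNet widths
        'cross_attention_dim': 1152,                   # example transformer width
    }

    def resolve_hidden_size(name: str) -> int:
        if name.startswith("mid_block"):
            return config['block_out_channels'][-1]
        elif name.startswith("up_blocks"):
            # assumed body for the elided branch: up blocks index reversed channels
            block_id = int(name[len("up_blocks.")])
            return list(reversed(config['block_out_channels']))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            return config['block_out_channels'][block_id]
        elif name.startswith("transformer"):
            # transformer backbones use a single width for every block
            return config['cross_attention_dim']
        raise ValueError(f"unknown attn processor name: {name}")

    print(resolve_hidden_size("down_blocks.1.attentions.0.processor"))  # 640
    print(resolve_hidden_size("transformer_blocks.0.attn2"))            # 1152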
@@ -244,6 +267,13 @@ class TEAdapter(torch.nn.Module):
        "to_k_adapter.weight": to_k_adapter,
        "to_v_adapter.weight": to_v_adapter,
    }

    if self.sd_ref().is_pixart:
        # pixart is much more sensitive
        weights = {
            "to_k_adapter.weight": weights["to_k_adapter.weight"] * 0.01,
            "to_v_adapter.weight": weights["to_v_adapter.weight"] * 0.01,
        }

    attn_procs[name] = TEAdapterAttnProcessor(
        hidden_size=hidden_size,
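Scaling the copied to_k/to_v weights by 0.01 means the pixart adapter starts with almost no influence on cross-attention, rather than injecting full-strength projections from step one. A toy illustration of the effect; the shapes here are hypothetical:

    import torch

    to_k_adapter = torch.randn(1152, 4096)  # hypothetical projection weight
    weights = {"to_k_adapter.weight": to_k_adapter}

    # pixart is much more sensitive, so the initial weights are shrunk 100x
    weights = {k: v * 0.01 for k, v in weights.items()}

    x = torch.randn(8, 4096)
    out = x @ weights["to_k_adapter.weight"].T
    print(out.abs().mean())  # ~1% of the unscaled projection's magnitude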
@@ -256,8 +286,20 @@ class TEAdapter(torch.nn.Module):
    )
    attn_procs[name].load_state_dict(weights)
    self.adapter_modules.append(attn_procs[name])
sd.unet.set_attn_processor(attn_procs)
self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
if self.sd_ref().is_pixart:
    # we have to set them ourselves
    transformer: Transformer2DModel = sd.unet
    for i, module in transformer.transformer_blocks.named_children():
        module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"]
        module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"]
    self.adapter_modules = torch.nn.ModuleList(
        [
            transformer.transformer_blocks[i].attn2.processor for i in
            range(len(transformer.transformer_blocks))
        ])
else:
    sd.unet.set_attn_processor(attn_procs)
    self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())

# make a getter to see if is active
@property
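Because set_attn_processor is a UNet convenience, the pixart branch assigns each block's .processor attribute by hand and then collects only the attn2 (cross-attention) processors into the ModuleList. A sketch of that wiring, with stub modules standing in for the diffusers types:

    import torch

    class AttnStub(torch.nn.Module):
        # diffusers attention modules expose a swappable .processor
        def __init__(self):
            super().__init__()
            self.processor = None

    class BlockStub(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.attn1 = AttnStub()  # self-attention
            self.attn2 = AttnStub()  # cross-attention

    transformer_blocks = torch.nn.ModuleList([BlockStub() for _ in range(2)])
    attn_procs = {f"transformer_blocks.{i}.attn{j}": torch.nn.Identity()
                  for i in range(2) for j in (1, 2)}

    # manual assignment, mirroring the pixart branch above
    for i, module in transformer_blocks.named_children():
        module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"]
        module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"]

    # only the cross-attention processors end up in the trainable ModuleList
    adapter_modules = torch.nn.ModuleList(
        [transformer_blocks[i].attn2.processor
         for i in range(len(transformer_blocks))])
    print(len(adapter_modules))  # 2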