Work to improve pixart training

Jaret Burkett
2024-06-23 20:46:48 +00:00
parent 5d47244c57
commit 7165f2d25a
6 changed files with 65 additions and 18 deletions

View File

@@ -589,16 +589,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
         return latest_path

     def load_training_state_from_metadata(self, path):
+        meta = None
         # if path is folder, then it is diffusers
         if os.path.isdir(path):
             meta_path = os.path.join(path, 'aitk_meta.yaml')
             # load it
-            with open(meta_path, 'r') as f:
-                meta = yaml.load(f, Loader=yaml.FullLoader)
+            if os.path.exists(meta_path):
+                with open(meta_path, 'r') as f:
+                    meta = yaml.load(f, Loader=yaml.FullLoader)
         else:
             meta = load_metadata_from_safetensors(path)
         # if 'training_info' in Orderdict keys
-        if 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None:
+        if meta is not None and 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None:
             self.step_num = meta['training_info']['step']
             if 'epoch' in meta['training_info']:
                 self.epoch_num = meta['training_info']['epoch']

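The guard above means a diffusers-format folder without an aitk_meta.yaml now simply resumes from step 0 instead of raising when opening the file; single-file safetensors saves keep taking the load_metadata_from_safetensors path. A minimal standalone sketch of the folder case (the file name and the training_info/step/epoch keys come from the diff, the helper itself is assumed):

import os
import yaml

def read_training_state(path: str) -> tuple[int, int]:
    # Diffusers-format saves are folders; the trainer writes an aitk_meta.yaml next to them.
    meta = None
    if os.path.isdir(path):
        meta_path = os.path.join(path, 'aitk_meta.yaml')
        if os.path.exists(meta_path):
            with open(meta_path, 'r') as f:
                meta = yaml.load(f, Loader=yaml.FullLoader)
    step, epoch = 0, 0
    if meta is not None and 'training_info' in meta:
        step = meta['training_info'].get('step', 0)
        epoch = meta['training_info'].get('epoch', 0)
    return step, epoch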
View File

@@ -23,6 +23,5 @@ prodigyopt
 controlnet_aux==0.0.7
 python-dotenv
 bitsandbytes
-xformers
 hf_transfer
 lpips

View File

@@ -372,6 +372,8 @@ class ModelConfig:
         # for text encoder quant. Only works with pixart currently
         self.text_encoder_bits = kwargs.get('text_encoder_bits', 8) # 16, 8, 4
+        self.unet_path = kwargs.get("unet_path", None)
+        self.unet_sample_size = kwargs.get("unet_sample_size", None)

 class ReferenceDatasetConfig:

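The two options touched here are plain ModelConfig kwargs. A minimal usage sketch, assuming ModelConfig is importable from toolkit.config_modules; only unet_path and unet_sample_size are taken from the diff, the checkpoint id and other values are illustrative:

from toolkit.config_modules import ModelConfig  # assumed import path

model_config = ModelConfig(
    name_or_path="PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",  # illustrative base checkpoint
    unet_path="/path/to/finetuned/transformer",  # optional: load the transformer from here instead
    unet_sample_size=64,  # optional: override transformer.config.sample_size after loading
)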
View File

@@ -333,7 +333,7 @@ class IPAdapter(torch.nn.Module):
         if not self.config.train_image_encoder:
             # compile it
             print('Compiling image encoder')
-            torch.compile(self.image_encoder, fullgraph=True)
+            #torch.compile(self.image_encoder, fullgraph=True)
         self.input_size = self.image_encoder.config.image_size

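One note on the line being disabled: calling torch.compile without using its return value never compiled anything, because torch.compile returns a new optimized wrapper rather than modifying the module in place. If compilation is brought back, a minimal sketch of the working pattern (standalone helper; the fullgraph flag mirrors the diff):

import torch

def compile_image_encoder(image_encoder: torch.nn.Module) -> torch.nn.Module:
    # torch.compile is not in-place: it returns an OptimizedModule wrapper,
    # so the result has to be assigned back to be of any use.
    if hasattr(torch, "compile"):
        print('Compiling image encoder')
        image_encoder = torch.compile(image_encoder, fullgraph=True)
    return image_encoder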
View File

@@ -11,6 +11,7 @@ from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokeniz
 from toolkit import train_tools
 from toolkit.paths import REPOS_ROOT
 from toolkit.prompt_utils import PromptEmbeds
+from diffusers import Transformer2DModel
 sys.path.append(REPOS_ROOT)
@@ -176,6 +177,7 @@ class TEAdapter(torch.nn.Module):
         self.te_ref: weakref.ref = weakref.ref(te)
         self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
         self.adapter_modules = []
+        is_pixart = sd.is_pixart
         if self.adapter_ref().config.text_encoder_arch == "t5":
             self.token_size = self.te_ref().config.d_model
@@ -195,9 +197,28 @@ class TEAdapter(torch.nn.Module):
         }
         module_idx = 0
         attn_processors_list = list(sd.unet.attn_processors.keys())
-        for name in sd.unet.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim']
+        # init adapter modules
+        attn_procs = {}
+        unet_sd = sd.unet.state_dict()
+        attn_processor_keys = []
+        if is_pixart:
+            transformer: Transformer2DModel = sd.unet
+            for i, module in transformer.transformer_blocks.named_children():
+                attn_processor_keys.append(f"transformer_blocks.{i}.attn1")
+                # cross attention
+                attn_processor_keys.append(f"transformer_blocks.{i}.attn2")
+        else:
+            attn_processor_keys = list(sd.unet.attn_processors.keys())
+        attn_processor_names = []
+        blocks = []
+        transformer_blocks = []
+        for name in attn_processor_keys:
+            cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \
+                sd.unet.config['cross_attention_dim']
             if name.startswith("mid_block"):
                 hidden_size = sd.unet.config['block_out_channels'][-1]
             elif name.startswith("up_blocks"):
@@ -206,6 +227,8 @@ class TEAdapter(torch.nn.Module):
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = sd.unet.config['block_out_channels'][block_id]
elif name.startswith("transformer"):
hidden_size = sd.unet.config['cross_attention_dim']
else:
# they didnt have this, but would lead to undefined below
raise ValueError(f"unknown attn processor name: {name}")
@@ -244,6 +267,13 @@ class TEAdapter(torch.nn.Module):
"to_k_adapter.weight": to_k_adapter,
"to_v_adapter.weight": to_v_adapter,
}
if self.sd_ref().is_pixart:
# pixart is much more sensitive
weights = {
"to_k_adapter.weight": weights["to_k_adapter.weight"] * 0.01,
"to_v_adapter.weight": weights["to_v_adapter.weight"] * 0.01,
}
attn_procs[name] = TEAdapterAttnProcessor(
hidden_size=hidden_size,
@@ -256,8 +286,20 @@ class TEAdapter(torch.nn.Module):
             )
             attn_procs[name].load_state_dict(weights)
             self.adapter_modules.append(attn_procs[name])
-        sd.unet.set_attn_processor(attn_procs)
-        self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
+        if self.sd_ref().is_pixart:
+            # we have to set them ourselves
+            transformer: Transformer2DModel = sd.unet
+            for i, module in transformer.transformer_blocks.named_children():
+                module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"]
+                module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"]
+            self.adapter_modules = torch.nn.ModuleList(
+                [
+                    transformer.transformer_blocks[i].attn2.processor for i in
+                    range(len(transformer.transformer_blocks))
+                ])
+        else:
+            sd.unet.set_attn_processor(attn_procs)
+            self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())

     # make a getter to see if is active
     @property

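The PixArt branch above exists because the two architectures expose attention differently: a UNet hands back a ready-made attn_processors dict with keys like "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor", while the PixArt Transformer2DModel is walked block by block and keyed as transformer_blocks.{i}.attn1 / attn2, with the processors then assigned directly onto each block. A small sketch of that traversal on its own (illustrative checkpoint id; the loop mirrors the diff):

from diffusers import Transformer2DModel

transformer = Transformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",  # illustrative PixArt Sigma checkpoint
    subfolder="transformer",
)

adapter_keys = []
for i, block in transformer.transformer_blocks.named_children():
    adapter_keys.append(f"transformer_blocks.{i}.attn1")  # self-attention: cross_attention_dim is None
    adapter_keys.append(f"transformer_blocks.{i}.attn2")  # cross-attention: receives the text embeddings

print(adapter_keys[:4])
# ['transformer_blocks.0.attn1', 'transformer_blocks.0.attn2',
#  'transformer_blocks.1.attn1', 'transformer_blocks.1.attn2']

Note that in the diff only the attn2 (cross-attention) processors are collected into adapter_modules for PixArt, so only those are registered as the adapter's trainable modules.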
View File

@@ -38,9 +38,9 @@ import torch
 from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
     StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
 from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
-    StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
+    StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
     StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
-    StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline
+    StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline
 import diffusers
 from diffusers import \
     AutoencoderKL, \
@@ -185,7 +185,6 @@ class StableDiffusion:
         }
         if self.model_config.vae_path is not None:
             load_args['vae'] = load_vae(self.model_config.vae_path, dtype)
-
         if self.model_config.is_xl or self.model_config.is_ssd or self.model_config.is_vega:
             if self.custom_pipeline is not None:
                 pipln = self.custom_pipeline
@@ -331,11 +330,11 @@ class StableDiffusion:
         if self.model_config.is_pixart_sigma:
             # load the transformer only from the save
             transformer = Transformer2DModel.from_pretrained(
-                model_path,
+                model_path if self.model_config.unet_path is None else self.model_config.unet_path,
                 torch_dtype=self.torch_dtype,
                 subfolder='transformer'
             )
-            pipe: PixArtAlphaPipeline = PixArtSigmaPipeline.from_pretrained(
+            pipe: PixArtSigmaPipeline = PixArtSigmaPipeline.from_pretrained(
                 main_model_path,
                 transformer=transformer,
                 text_encoder=text_encoder,
@@ -357,6 +356,9 @@ class StableDiffusion:
                 device=self.device_torch,
                 **load_args
             ).to(self.device_torch)
+            if self.model_config.unet_sample_size is not None:
+                pipe.transformer.config.sample_size = self.model_config.unet_sample_size
             pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
             flush()
@@ -592,7 +594,7 @@ class StableDiffusion:
                 **extra_args
             )
         elif self.is_pixart:
-            pipeline = PixArtAlphaPipeline(
+            pipeline = PixArtSigmaPipeline(
                 vae=self.vae,
                 transformer=self.unet,
                 text_encoder=self.text_encoder,
@@ -1243,7 +1245,7 @@ class StableDiffusion:
             elif self.pipeline.transformer.config.sample_size == 32:
                 aspect_ratio_bin = ASPECT_RATIO_256_BIN
             else:
-                raise ValueError("Invalid sample size")
+                raise ValueError(f"Invalid sample size: {self.pipeline.transformer.config.sample_size}")
             orig_height, orig_width = height, width
             height, width = self.pipeline.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
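For context on the sample_size branch: the PixArt pipelines choose an aspect-ratio bucket table from the transformer's configured latent sample size, which is exactly what the new unet_sample_size override feeds into. A sketch of the full mapping as it appears in the diffusers PixArt Alpha/Sigma pipelines (only the 32 case is visible in the diff; the other values are an assumption based on those pipelines):

# sample_size is the latent resolution; pixel resolution is 8x that (the VAE downscale factor).
SAMPLE_SIZE_TO_BIN_RESOLUTION = {
    256: 2048,  # ASPECT_RATIO_2048_BIN (PixArt Sigma only)
    128: 1024,  # ASPECT_RATIO_1024_BIN
    64: 512,    # ASPECT_RATIO_512_BIN
    32: 256,    # ASPECT_RATIO_256_BIN  (the branch shown in the diff)
}

def bin_resolution(sample_size: int) -> int:
    if sample_size not in SAMPLE_SIZE_TO_BIN_RESOLUTION:
        # same failure mode as above, now with the offending value in the message
        raise ValueError(f"Invalid sample size: {sample_size}")
    return SAMPLE_SIZE_TO_BIN_RESOLUTION[sample_size]

Overriding unet_sample_size to 64 on a 1024 checkpoint, for example, would push sampling onto the 512px buckets, which appears to be the point of exposing it in ModelConfig.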