Work to omprove pixart training

This commit is contained in:
Jaret Burkett
2024-06-23 20:46:48 +00:00
parent 5d47244c57
commit 7165f2d25a
6 changed files with 65 additions and 18 deletions

View File

@@ -589,16 +589,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
return latest_path return latest_path
def load_training_state_from_metadata(self, path): def load_training_state_from_metadata(self, path):
meta = None
# if path is folder, then it is diffusers # if path is folder, then it is diffusers
if os.path.isdir(path): if os.path.isdir(path):
meta_path = os.path.join(path, 'aitk_meta.yaml') meta_path = os.path.join(path, 'aitk_meta.yaml')
# load it # load it
with open(meta_path, 'r') as f: if os.path.exists(meta_path):
meta = yaml.load(f, Loader=yaml.FullLoader) with open(meta_path, 'r') as f:
meta = yaml.load(f, Loader=yaml.FullLoader)
else: else:
meta = load_metadata_from_safetensors(path) meta = load_metadata_from_safetensors(path)
# if 'training_info' in Orderdict keys # if 'training_info' in Orderdict keys
if 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None: if meta is not None and 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None:
self.step_num = meta['training_info']['step'] self.step_num = meta['training_info']['step']
if 'epoch' in meta['training_info']: if 'epoch' in meta['training_info']:
self.epoch_num = meta['training_info']['epoch'] self.epoch_num = meta['training_info']['epoch']

View File

@@ -23,6 +23,5 @@ prodigyopt
controlnet_aux==0.0.7 controlnet_aux==0.0.7
python-dotenv python-dotenv
bitsandbytes bitsandbytes
xformers
hf_transfer hf_transfer
lpips lpips

View File

@@ -372,6 +372,8 @@ class ModelConfig:
# for text encoder quant. Only works with pixart currently # for text encoder quant. Only works with pixart currently
self.text_encoder_bits = kwargs.get('text_encoder_bits', 8) # 16, 8, 4 self.text_encoder_bits = kwargs.get('text_encoder_bits', 8) # 16, 8, 4
self.unet_path = kwargs.get("unet_path", None)
self.unet_sample_size = kwargs.get("unet_sample_size", None)
class ReferenceDatasetConfig: class ReferenceDatasetConfig:

View File

@@ -333,7 +333,7 @@ class IPAdapter(torch.nn.Module):
if not self.config.train_image_encoder: if not self.config.train_image_encoder:
# compile it # compile it
print('Compiling image encoder') print('Compiling image encoder')
torch.compile(self.image_encoder, fullgraph=True) #torch.compile(self.image_encoder, fullgraph=True)
self.input_size = self.image_encoder.config.image_size self.input_size = self.image_encoder.config.image_size

View File

@@ -11,6 +11,7 @@ from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokeniz
from toolkit import train_tools from toolkit import train_tools
from toolkit.paths import REPOS_ROOT from toolkit.paths import REPOS_ROOT
from toolkit.prompt_utils import PromptEmbeds from toolkit.prompt_utils import PromptEmbeds
from diffusers import Transformer2DModel
sys.path.append(REPOS_ROOT) sys.path.append(REPOS_ROOT)
@@ -176,6 +177,7 @@ class TEAdapter(torch.nn.Module):
self.te_ref: weakref.ref = weakref.ref(te) self.te_ref: weakref.ref = weakref.ref(te)
self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer) self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
self.adapter_modules = [] self.adapter_modules = []
is_pixart = sd.is_pixart
if self.adapter_ref().config.text_encoder_arch == "t5": if self.adapter_ref().config.text_encoder_arch == "t5":
self.token_size = self.te_ref().config.d_model self.token_size = self.te_ref().config.d_model
@@ -195,9 +197,28 @@ class TEAdapter(torch.nn.Module):
} }
module_idx = 0 module_idx = 0
attn_processors_list = list(sd.unet.attn_processors.keys()) # init adapter modules
for name in sd.unet.attn_processors.keys(): attn_procs = {}
cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim'] unet_sd = sd.unet.state_dict()
attn_processor_keys = []
if is_pixart:
transformer: Transformer2DModel = sd.unet
for i, module in transformer.transformer_blocks.named_children():
attn_processor_keys.append(f"transformer_blocks.{i}.attn1")
# cross attention
attn_processor_keys.append(f"transformer_blocks.{i}.attn2")
else:
attn_processor_keys = list(sd.unet.attn_processors.keys())
attn_processor_names = []
blocks = []
transformer_blocks = []
for name in attn_processor_keys:
cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \
sd.unet.config['cross_attention_dim']
if name.startswith("mid_block"): if name.startswith("mid_block"):
hidden_size = sd.unet.config['block_out_channels'][-1] hidden_size = sd.unet.config['block_out_channels'][-1]
elif name.startswith("up_blocks"): elif name.startswith("up_blocks"):
@@ -206,6 +227,8 @@ class TEAdapter(torch.nn.Module):
elif name.startswith("down_blocks"): elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")]) block_id = int(name[len("down_blocks.")])
hidden_size = sd.unet.config['block_out_channels'][block_id] hidden_size = sd.unet.config['block_out_channels'][block_id]
elif name.startswith("transformer"):
hidden_size = sd.unet.config['cross_attention_dim']
else: else:
# they didnt have this, but would lead to undefined below # they didnt have this, but would lead to undefined below
raise ValueError(f"unknown attn processor name: {name}") raise ValueError(f"unknown attn processor name: {name}")
@@ -244,6 +267,13 @@ class TEAdapter(torch.nn.Module):
"to_k_adapter.weight": to_k_adapter, "to_k_adapter.weight": to_k_adapter,
"to_v_adapter.weight": to_v_adapter, "to_v_adapter.weight": to_v_adapter,
} }
if self.sd_ref().is_pixart:
# pixart is much more sensitive
weights = {
"to_k_adapter.weight": weights["to_k_adapter.weight"] * 0.01,
"to_v_adapter.weight": weights["to_v_adapter.weight"] * 0.01,
}
attn_procs[name] = TEAdapterAttnProcessor( attn_procs[name] = TEAdapterAttnProcessor(
hidden_size=hidden_size, hidden_size=hidden_size,
@@ -256,8 +286,20 @@ class TEAdapter(torch.nn.Module):
) )
attn_procs[name].load_state_dict(weights) attn_procs[name].load_state_dict(weights)
self.adapter_modules.append(attn_procs[name]) self.adapter_modules.append(attn_procs[name])
sd.unet.set_attn_processor(attn_procs) if self.sd_ref().is_pixart:
self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values()) # we have to set them ourselves
transformer: Transformer2DModel = sd.unet
for i, module in transformer.transformer_blocks.named_children():
module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"]
module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"]
self.adapter_modules = torch.nn.ModuleList(
[
transformer.transformer_blocks[i].attn2.processor for i in
range(len(transformer.transformer_blocks))
])
else:
sd.unet.set_attn_processor(attn_procs)
self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
# make a getter to see if is active # make a getter to see if is active
@property @property

View File

@@ -38,9 +38,9 @@ import torch
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \ StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \ StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline
import diffusers import diffusers
from diffusers import \ from diffusers import \
AutoencoderKL, \ AutoencoderKL, \
@@ -185,7 +185,6 @@ class StableDiffusion:
} }
if self.model_config.vae_path is not None: if self.model_config.vae_path is not None:
load_args['vae'] = load_vae(self.model_config.vae_path, dtype) load_args['vae'] = load_vae(self.model_config.vae_path, dtype)
if self.model_config.is_xl or self.model_config.is_ssd or self.model_config.is_vega: if self.model_config.is_xl or self.model_config.is_ssd or self.model_config.is_vega:
if self.custom_pipeline is not None: if self.custom_pipeline is not None:
pipln = self.custom_pipeline pipln = self.custom_pipeline
@@ -331,11 +330,11 @@ class StableDiffusion:
if self.model_config.is_pixart_sigma: if self.model_config.is_pixart_sigma:
# load the transformer only from the save # load the transformer only from the save
transformer = Transformer2DModel.from_pretrained( transformer = Transformer2DModel.from_pretrained(
model_path, model_path if self.model_config.unet_path is None else self.model_config.unet_path,
torch_dtype=self.torch_dtype, torch_dtype=self.torch_dtype,
subfolder='transformer' subfolder='transformer'
) )
pipe: PixArtAlphaPipeline = PixArtSigmaPipeline.from_pretrained( pipe: PixArtSigmaPipeline = PixArtSigmaPipeline.from_pretrained(
main_model_path, main_model_path,
transformer=transformer, transformer=transformer,
text_encoder=text_encoder, text_encoder=text_encoder,
@@ -357,6 +356,9 @@ class StableDiffusion:
device=self.device_torch, device=self.device_torch,
**load_args **load_args
).to(self.device_torch) ).to(self.device_torch)
if self.model_config.unet_sample_size is not None:
pipe.transformer.config.sample_size = self.model_config.unet_sample_size
pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype) pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
flush() flush()
@@ -592,7 +594,7 @@ class StableDiffusion:
**extra_args **extra_args
) )
elif self.is_pixart: elif self.is_pixart:
pipeline = PixArtAlphaPipeline( pipeline = PixArtSigmaPipeline(
vae=self.vae, vae=self.vae,
transformer=self.unet, transformer=self.unet,
text_encoder=self.text_encoder, text_encoder=self.text_encoder,
@@ -1243,7 +1245,7 @@ class StableDiffusion:
elif self.pipeline.transformer.config.sample_size == 32: elif self.pipeline.transformer.config.sample_size == 32:
aspect_ratio_bin = ASPECT_RATIO_256_BIN aspect_ratio_bin = ASPECT_RATIO_256_BIN
else: else:
raise ValueError("Invalid sample size") raise ValueError(f"Invalid sample size: {self.pipeline.transformer.config.sample_size}")
orig_height, orig_width = height, width orig_height, orig_width = height, width
height, width = self.pipeline.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) height, width = self.pipeline.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)