Mirror of https://github.com/ostris/ai-toolkit.git
Work to improve pixart training
@@ -589,16 +589,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
        return latest_path

    def load_training_state_from_metadata(self, path):
+       meta = None
        # if path is folder, then it is diffusers
        if os.path.isdir(path):
            meta_path = os.path.join(path, 'aitk_meta.yaml')
            # load it
-           with open(meta_path, 'r') as f:
-               meta = yaml.load(f, Loader=yaml.FullLoader)
+           if os.path.exists(meta_path):
+               with open(meta_path, 'r') as f:
+                   meta = yaml.load(f, Loader=yaml.FullLoader)
        else:
            meta = load_metadata_from_safetensors(path)
        # if 'training_info' in Orderdict keys
-       if 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None:
+       if meta is not None and 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None:
            self.step_num = meta['training_info']['step']
            if 'epoch' in meta['training_info']:
                self.epoch_num = meta['training_info']['epoch']
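For context, a minimal standalone sketch of the guarded metadata read shown above. It is illustrative only: the function name is hypothetical, and the real method additionally falls back to load_metadata_from_safetensors for single-file checkpoints.

import os
import yaml

def read_training_meta(path):
    # Hypothetical helper mirroring the guarded read above: a diffusers folder
    # without aitk_meta.yaml now yields None instead of raising FileNotFoundError.
    meta = None
    if os.path.isdir(path):
        meta_path = os.path.join(path, 'aitk_meta.yaml')
        if os.path.exists(meta_path):
            with open(meta_path, 'r') as f:
                meta = yaml.load(f, Loader=yaml.FullLoader)
    # (the real method reads the safetensors header here for single-file checkpoints)
    return meta

print(read_training_meta('/path/without/metadata'))  # -> None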
@@ -23,6 +23,5 @@ prodigyopt
controlnet_aux==0.0.7
python-dotenv
bitsandbytes
xformers
hf_transfer
lpips
lpips
@@ -372,6 +372,8 @@ class ModelConfig:

        # for text encoder quant. Only works with pixart currently
        self.text_encoder_bits = kwargs.get('text_encoder_bits', 8)  # 16, 8, 4
+       self.unet_path = kwargs.get("unet_path", None)
+       self.unet_sample_size = kwargs.get("unet_sample_size", None)


class ReferenceDatasetConfig:
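As a usage sketch (not part of the commit): these keys are read straight from the job config kwargs, so a PixArt run could set them roughly as below. The model name and the is_pixart_sigma flag spelling are assumptions; only text_encoder_bits, unet_path, and unet_sample_size come from the diff above.

# Hypothetical config kwargs for a PixArt Sigma run; key names mirror the
# kwargs.get(...) calls in the diff, values are purely illustrative.
model_kwargs = {
    "name_or_path": "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",  # assumed example model
    "is_pixart_sigma": True,          # assumed flag spelling
    "text_encoder_bits": 8,           # quantize the T5 text encoder: 16, 8, or 4
    "unet_path": None,                # optionally load the transformer from a separate checkpoint
    "unet_sample_size": None,         # optionally override transformer.config.sample_size
}
# model_config = ModelConfig(**model_kwargs)  # ModelConfig as defined in this file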
@@ -333,7 +333,7 @@ class IPAdapter(torch.nn.Module):
        if not self.config.train_image_encoder:
            # compile it
            print('Compiling image encoder')
-           torch.compile(self.image_encoder, fullgraph=True)
+           #torch.compile(self.image_encoder, fullgraph=True)

        self.input_size = self.image_encoder.config.image_size
@@ -11,6 +11,7 @@ from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokeniz
from toolkit import train_tools
from toolkit.paths import REPOS_ROOT
from toolkit.prompt_utils import PromptEmbeds
+from diffusers import Transformer2DModel

sys.path.append(REPOS_ROOT)
@@ -176,6 +177,7 @@ class TEAdapter(torch.nn.Module):
        self.te_ref: weakref.ref = weakref.ref(te)
        self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
        self.adapter_modules = []
+       is_pixart = sd.is_pixart

        if self.adapter_ref().config.text_encoder_arch == "t5":
            self.token_size = self.te_ref().config.d_model
@@ -195,9 +197,28 @@ class TEAdapter(torch.nn.Module):

        }
        module_idx = 0
        attn_processors_list = list(sd.unet.attn_processors.keys())
-       for name in sd.unet.attn_processors.keys():
-           cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim']
        # init adapter modules
        attn_procs = {}
        unet_sd = sd.unet.state_dict()
+       attn_processor_keys = []
+       if is_pixart:
+           transformer: Transformer2DModel = sd.unet
+           for i, module in transformer.transformer_blocks.named_children():
+               attn_processor_keys.append(f"transformer_blocks.{i}.attn1")

+               # cross attention
+               attn_processor_keys.append(f"transformer_blocks.{i}.attn2")

+       else:
+           attn_processor_keys = list(sd.unet.attn_processors.keys())

        attn_processor_names = []

        blocks = []
        transformer_blocks = []
+       for name in attn_processor_keys:
+           cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \
+               sd.unet.config['cross_attention_dim']
            if name.startswith("mid_block"):
                hidden_size = sd.unet.config['block_out_channels'][-1]
            elif name.startswith("up_blocks"):
@@ -206,6 +227,8 @@ class TEAdapter(torch.nn.Module):
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = sd.unet.config['block_out_channels'][block_id]
+           elif name.startswith("transformer"):
+               hidden_size = sd.unet.config['cross_attention_dim']
            else:
                # they didnt have this, but would lead to undefined below
                raise ValueError(f"unknown attn processor name: {name}")
@@ -244,6 +267,13 @@ class TEAdapter(torch.nn.Module):
                "to_k_adapter.weight": to_k_adapter,
                "to_v_adapter.weight": to_v_adapter,
            }

+           if self.sd_ref().is_pixart:
+               # pixart is much more sensitive
+               weights = {
+                   "to_k_adapter.weight": weights["to_k_adapter.weight"] * 0.01,
+                   "to_v_adapter.weight": weights["to_v_adapter.weight"] * 0.01,
+               }

            attn_procs[name] = TEAdapterAttnProcessor(
                hidden_size=hidden_size,
@@ -256,8 +286,20 @@ class TEAdapter(torch.nn.Module):
            )
            attn_procs[name].load_state_dict(weights)
            self.adapter_modules.append(attn_procs[name])
-       sd.unet.set_attn_processor(attn_procs)
-       self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
+       if self.sd_ref().is_pixart:
+           # we have to set them ourselves
+           transformer: Transformer2DModel = sd.unet
+           for i, module in transformer.transformer_blocks.named_children():
+               module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"]
+               module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"]
+           self.adapter_modules = torch.nn.ModuleList(
+               [
+                   transformer.transformer_blocks[i].attn2.processor for i in
+                   range(len(transformer.transformer_blocks))
+               ])
+       else:
+           sd.unet.set_attn_processor(attn_procs)
+           self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())

    # make a getter to see if is active
    @property
@@ -38,9 +38,9 @@ import torch
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
    StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
-   StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
+   StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
    StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
-   StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline
+   StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline
import diffusers
from diffusers import \
    AutoencoderKL, \
@@ -185,7 +185,6 @@ class StableDiffusion:
            }
            if self.model_config.vae_path is not None:
                load_args['vae'] = load_vae(self.model_config.vae_path, dtype)

        if self.model_config.is_xl or self.model_config.is_ssd or self.model_config.is_vega:
            if self.custom_pipeline is not None:
                pipln = self.custom_pipeline
@@ -331,11 +330,11 @@ class StableDiffusion:
            if self.model_config.is_pixart_sigma:
                # load the transformer only from the save
                transformer = Transformer2DModel.from_pretrained(
-                   model_path,
+                   model_path if self.model_config.unet_path is None else self.model_config.unet_path,
                    torch_dtype=self.torch_dtype,
                    subfolder='transformer'
                )
-               pipe: PixArtAlphaPipeline = PixArtSigmaPipeline.from_pretrained(
+               pipe: PixArtSigmaPipeline = PixArtSigmaPipeline.from_pretrained(
                    main_model_path,
                    transformer=transformer,
                    text_encoder=text_encoder,
@@ -357,6 +356,9 @@ class StableDiffusion:
                device=self.device_torch,
                **load_args
            ).to(self.device_torch)

+           if self.model_config.unet_sample_size is not None:
+               pipe.transformer.config.sample_size = self.model_config.unet_sample_size
            pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)

            flush()
@@ -592,7 +594,7 @@ class StableDiffusion:
                **extra_args
            )
        elif self.is_pixart:
-           pipeline = PixArtAlphaPipeline(
+           pipeline = PixArtSigmaPipeline(
                vae=self.vae,
                transformer=self.unet,
                text_encoder=self.text_encoder,
@@ -1243,7 +1245,7 @@ class StableDiffusion:
        elif self.pipeline.transformer.config.sample_size == 32:
            aspect_ratio_bin = ASPECT_RATIO_256_BIN
        else:
-           raise ValueError("Invalid sample size")
+           raise ValueError(f"Invalid sample size: {self.pipeline.transformer.config.sample_size}")
        orig_height, orig_width = height, width
        height, width = self.pipeline.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)