mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 02:31:17 +00:00
Work to improve pixart training
This commit is contained in:
@@ -589,16 +589,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
return latest_path
|
return latest_path
|
||||||
|
|
||||||
def load_training_state_from_metadata(self, path):
|
def load_training_state_from_metadata(self, path):
|
||||||
|
meta = None
|
||||||
# if path is folder, then it is diffusers
|
# if path is folder, then it is diffusers
|
||||||
if os.path.isdir(path):
|
if os.path.isdir(path):
|
||||||
meta_path = os.path.join(path, 'aitk_meta.yaml')
|
meta_path = os.path.join(path, 'aitk_meta.yaml')
|
||||||
# load it
|
# load it
|
||||||
with open(meta_path, 'r') as f:
|
if os.path.exists(meta_path):
|
||||||
meta = yaml.load(f, Loader=yaml.FullLoader)
|
with open(meta_path, 'r') as f:
|
||||||
|
meta = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
else:
|
else:
|
||||||
meta = load_metadata_from_safetensors(path)
|
meta = load_metadata_from_safetensors(path)
|
||||||
# if 'training_info' in Orderdict keys
|
# if 'training_info' in Orderdict keys
|
||||||
if 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None:
|
if meta is not None and 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None:
|
||||||
self.step_num = meta['training_info']['step']
|
self.step_num = meta['training_info']['step']
|
||||||
if 'epoch' in meta['training_info']:
|
if 'epoch' in meta['training_info']:
|
||||||
self.epoch_num = meta['training_info']['epoch']
|
self.epoch_num = meta['training_info']['epoch']
|
||||||
|
|||||||
@@ -23,6 +23,5 @@ prodigyopt
|
|||||||
controlnet_aux==0.0.7
|
controlnet_aux==0.0.7
|
||||||
python-dotenv
|
python-dotenv
|
||||||
bitsandbytes
|
bitsandbytes
|
||||||
xformers
|
|
||||||
hf_transfer
|
hf_transfer
|
||||||
lpips
|
lpips
|
||||||
|
|||||||
@@ -372,6 +372,8 @@ class ModelConfig:
|
|||||||
|
|
||||||
# for text encoder quant. Only works with pixart currently
|
# for text encoder quant. Only works with pixart currently
|
||||||
self.text_encoder_bits = kwargs.get('text_encoder_bits', 8) # 16, 8, 4
|
self.text_encoder_bits = kwargs.get('text_encoder_bits', 8) # 16, 8, 4
|
||||||
|
self.unet_path = kwargs.get("unet_path", None)
|
||||||
|
self.unet_sample_size = kwargs.get("unet_sample_size", None)
|
||||||
|
|
||||||
|
|
||||||
class ReferenceDatasetConfig:
|
class ReferenceDatasetConfig:
|
||||||
|
|||||||
@@ -333,7 +333,7 @@ class IPAdapter(torch.nn.Module):
|
|||||||
if not self.config.train_image_encoder:
|
if not self.config.train_image_encoder:
|
||||||
# compile it
|
# compile it
|
||||||
print('Compiling image encoder')
|
print('Compiling image encoder')
|
||||||
torch.compile(self.image_encoder, fullgraph=True)
|
#torch.compile(self.image_encoder, fullgraph=True)
|
||||||
|
|
||||||
self.input_size = self.image_encoder.config.image_size
|
self.input_size = self.image_encoder.config.image_size
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokeniz
|
|||||||
from toolkit import train_tools
|
from toolkit import train_tools
|
||||||
from toolkit.paths import REPOS_ROOT
|
from toolkit.paths import REPOS_ROOT
|
||||||
from toolkit.prompt_utils import PromptEmbeds
|
from toolkit.prompt_utils import PromptEmbeds
|
||||||
|
from diffusers import Transformer2DModel
|
||||||
|
|
||||||
sys.path.append(REPOS_ROOT)
|
sys.path.append(REPOS_ROOT)
|
||||||
|
|
||||||
@@ -176,6 +177,7 @@ class TEAdapter(torch.nn.Module):
|
|||||||
self.te_ref: weakref.ref = weakref.ref(te)
|
self.te_ref: weakref.ref = weakref.ref(te)
|
||||||
self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
|
self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
|
||||||
self.adapter_modules = []
|
self.adapter_modules = []
|
||||||
|
is_pixart = sd.is_pixart
|
||||||
|
|
||||||
if self.adapter_ref().config.text_encoder_arch == "t5":
|
if self.adapter_ref().config.text_encoder_arch == "t5":
|
||||||
self.token_size = self.te_ref().config.d_model
|
self.token_size = self.te_ref().config.d_model
|
||||||
@@ -195,9 +197,28 @@ class TEAdapter(torch.nn.Module):
|
|||||||
|
|
||||||
}
|
}
|
||||||
module_idx = 0
|
module_idx = 0
|
||||||
attn_processors_list = list(sd.unet.attn_processors.keys())
|
# init adapter modules
|
||||||
for name in sd.unet.attn_processors.keys():
|
attn_procs = {}
|
||||||
cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim']
|
unet_sd = sd.unet.state_dict()
|
||||||
|
attn_processor_keys = []
|
||||||
|
if is_pixart:
|
||||||
|
transformer: Transformer2DModel = sd.unet
|
||||||
|
for i, module in transformer.transformer_blocks.named_children():
|
||||||
|
attn_processor_keys.append(f"transformer_blocks.{i}.attn1")
|
||||||
|
|
||||||
|
# cross attention
|
||||||
|
attn_processor_keys.append(f"transformer_blocks.{i}.attn2")
|
||||||
|
|
||||||
|
else:
|
||||||
|
attn_processor_keys = list(sd.unet.attn_processors.keys())
|
||||||
|
|
||||||
|
attn_processor_names = []
|
||||||
|
|
||||||
|
blocks = []
|
||||||
|
transformer_blocks = []
|
||||||
|
for name in attn_processor_keys:
|
||||||
|
cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \
|
||||||
|
sd.unet.config['cross_attention_dim']
|
||||||
if name.startswith("mid_block"):
|
if name.startswith("mid_block"):
|
||||||
hidden_size = sd.unet.config['block_out_channels'][-1]
|
hidden_size = sd.unet.config['block_out_channels'][-1]
|
||||||
elif name.startswith("up_blocks"):
|
elif name.startswith("up_blocks"):
|
||||||
@@ -206,6 +227,8 @@ class TEAdapter(torch.nn.Module):
|
|||||||
elif name.startswith("down_blocks"):
|
elif name.startswith("down_blocks"):
|
||||||
block_id = int(name[len("down_blocks.")])
|
block_id = int(name[len("down_blocks.")])
|
||||||
hidden_size = sd.unet.config['block_out_channels'][block_id]
|
hidden_size = sd.unet.config['block_out_channels'][block_id]
|
||||||
|
elif name.startswith("transformer"):
|
||||||
|
hidden_size = sd.unet.config['cross_attention_dim']
|
||||||
else:
|
else:
|
||||||
# they didnt have this, but would lead to undefined below
|
# they didnt have this, but would lead to undefined below
|
||||||
raise ValueError(f"unknown attn processor name: {name}")
|
raise ValueError(f"unknown attn processor name: {name}")
|
||||||
@@ -244,6 +267,13 @@ class TEAdapter(torch.nn.Module):
|
|||||||
"to_k_adapter.weight": to_k_adapter,
|
"to_k_adapter.weight": to_k_adapter,
|
||||||
"to_v_adapter.weight": to_v_adapter,
|
"to_v_adapter.weight": to_v_adapter,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.sd_ref().is_pixart:
|
||||||
|
# pixart is much more sensitive
|
||||||
|
weights = {
|
||||||
|
"to_k_adapter.weight": weights["to_k_adapter.weight"] * 0.01,
|
||||||
|
"to_v_adapter.weight": weights["to_v_adapter.weight"] * 0.01,
|
||||||
|
}
|
||||||
|
|
||||||
attn_procs[name] = TEAdapterAttnProcessor(
|
attn_procs[name] = TEAdapterAttnProcessor(
|
||||||
hidden_size=hidden_size,
|
hidden_size=hidden_size,
|
||||||
@@ -256,8 +286,20 @@ class TEAdapter(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
attn_procs[name].load_state_dict(weights)
|
attn_procs[name].load_state_dict(weights)
|
||||||
self.adapter_modules.append(attn_procs[name])
|
self.adapter_modules.append(attn_procs[name])
|
||||||
sd.unet.set_attn_processor(attn_procs)
|
if self.sd_ref().is_pixart:
|
||||||
self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
|
# we have to set them ourselves
|
||||||
|
transformer: Transformer2DModel = sd.unet
|
||||||
|
for i, module in transformer.transformer_blocks.named_children():
|
||||||
|
module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"]
|
||||||
|
module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"]
|
||||||
|
self.adapter_modules = torch.nn.ModuleList(
|
||||||
|
[
|
||||||
|
transformer.transformer_blocks[i].attn2.processor for i in
|
||||||
|
range(len(transformer.transformer_blocks))
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
sd.unet.set_attn_processor(attn_procs)
|
||||||
|
self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
|
||||||
|
|
||||||
# make a getter to see if is active
|
# make a getter to see if is active
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -38,9 +38,9 @@ import torch
|
|||||||
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
|
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
|
||||||
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
|
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
|
||||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
|
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
|
||||||
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
|
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
|
||||||
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
|
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
|
||||||
StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline
|
StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline
|
||||||
import diffusers
|
import diffusers
|
||||||
from diffusers import \
|
from diffusers import \
|
||||||
AutoencoderKL, \
|
AutoencoderKL, \
|
||||||
@@ -185,7 +185,6 @@ class StableDiffusion:
|
|||||||
}
|
}
|
||||||
if self.model_config.vae_path is not None:
|
if self.model_config.vae_path is not None:
|
||||||
load_args['vae'] = load_vae(self.model_config.vae_path, dtype)
|
load_args['vae'] = load_vae(self.model_config.vae_path, dtype)
|
||||||
|
|
||||||
if self.model_config.is_xl or self.model_config.is_ssd or self.model_config.is_vega:
|
if self.model_config.is_xl or self.model_config.is_ssd or self.model_config.is_vega:
|
||||||
if self.custom_pipeline is not None:
|
if self.custom_pipeline is not None:
|
||||||
pipln = self.custom_pipeline
|
pipln = self.custom_pipeline
|
||||||
@@ -331,11 +330,11 @@ class StableDiffusion:
|
|||||||
if self.model_config.is_pixart_sigma:
|
if self.model_config.is_pixart_sigma:
|
||||||
# load the transformer only from the save
|
# load the transformer only from the save
|
||||||
transformer = Transformer2DModel.from_pretrained(
|
transformer = Transformer2DModel.from_pretrained(
|
||||||
model_path,
|
model_path if self.model_config.unet_path is None else self.model_config.unet_path,
|
||||||
torch_dtype=self.torch_dtype,
|
torch_dtype=self.torch_dtype,
|
||||||
subfolder='transformer'
|
subfolder='transformer'
|
||||||
)
|
)
|
||||||
pipe: PixArtAlphaPipeline = PixArtSigmaPipeline.from_pretrained(
|
pipe: PixArtSigmaPipeline = PixArtSigmaPipeline.from_pretrained(
|
||||||
main_model_path,
|
main_model_path,
|
||||||
transformer=transformer,
|
transformer=transformer,
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
@@ -357,6 +356,9 @@ class StableDiffusion:
|
|||||||
device=self.device_torch,
|
device=self.device_torch,
|
||||||
**load_args
|
**load_args
|
||||||
).to(self.device_torch)
|
).to(self.device_torch)
|
||||||
|
|
||||||
|
if self.model_config.unet_sample_size is not None:
|
||||||
|
pipe.transformer.config.sample_size = self.model_config.unet_sample_size
|
||||||
pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
|
pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
|
||||||
|
|
||||||
flush()
|
flush()
|
||||||
@@ -592,7 +594,7 @@ class StableDiffusion:
|
|||||||
**extra_args
|
**extra_args
|
||||||
)
|
)
|
||||||
elif self.is_pixart:
|
elif self.is_pixart:
|
||||||
pipeline = PixArtAlphaPipeline(
|
pipeline = PixArtSigmaPipeline(
|
||||||
vae=self.vae,
|
vae=self.vae,
|
||||||
transformer=self.unet,
|
transformer=self.unet,
|
||||||
text_encoder=self.text_encoder,
|
text_encoder=self.text_encoder,
|
||||||
@@ -1243,7 +1245,7 @@ class StableDiffusion:
|
|||||||
elif self.pipeline.transformer.config.sample_size == 32:
|
elif self.pipeline.transformer.config.sample_size == 32:
|
||||||
aspect_ratio_bin = ASPECT_RATIO_256_BIN
|
aspect_ratio_bin = ASPECT_RATIO_256_BIN
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid sample size")
|
raise ValueError(f"Invalid sample size: {self.pipeline.transformer.config.sample_size}")
|
||||||
orig_height, orig_width = height, width
|
orig_height, orig_width = height, width
|
||||||
height, width = self.pipeline.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
|
height, width = self.pipeline.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user