Added training for pixart-a

This commit is contained in:
Jaret Burkett
2024-02-13 16:00:04 -07:00
parent 4ec4025cbb
commit 93b52932c1
10 changed files with 288 additions and 24 deletions

View File

@@ -11,6 +11,7 @@ from collections import OrderedDict
import yaml
from PIL import Image
from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
from safetensors.torch import save_file, load_file
from torch.nn import Parameter
@@ -43,6 +44,8 @@ import diffusers
from diffusers import \
AutoencoderKL, \
UNet2DConditionModel
from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler
from transformers import T5EncoderModel
from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
@@ -121,7 +124,7 @@ class StableDiffusion:
self.device_state = None
self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline']
self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline']
self.vae: Union[None, 'AutoencoderKL']
self.unet: Union[None, 'UNet2DConditionModel']
self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
@@ -142,6 +145,7 @@ class StableDiffusion:
self.is_v2 = model_config.is_v2
self.is_ssd = model_config.is_ssd
self.is_vega = model_config.is_vega
self.is_pixart = model_config.is_pixart
self.use_text_encoder_1 = model_config.use_text_encoder_1
self.use_text_encoder_2 = model_config.use_text_encoder_2
@@ -157,7 +161,9 @@ class StableDiffusion:
scheduler = get_sampler(
'ddpm', {
"prediction_type": self.prediction_type,
})
},
'sd' if not self.is_pixart else 'pixart'
)
self.noise_scheduler = scheduler
# move the betas alphas and alphas_cumprod to device. Sometimed they get stuck on cpu, not sure why
@@ -227,7 +233,33 @@ class StableDiffusion:
te1_state_dict['text_projection.weight'] = replacement_weight.to(self.device_torch, dtype=dtype)
flush()
print("Injecting alt weights")
elif self.model_config.is_pixart:
# load the TE in 8bit mode
text_encoder = T5EncoderModel.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS",
subfolder="text_encoder",
load_in_8bit=True,
device_map="auto",
torch_dtype=self.torch_dtype,
)
# replace the to function with a no-op since it throws an error instead of a warning
text_encoder.to = lambda *args, **kwargs: None
pipe: PixArtAlphaPipeline = PixArtAlphaPipeline.from_pretrained(
model_path,
text_encoder=text_encoder,
dtype=dtype,
device=self.device_torch,
**load_args
).to(self.device_torch)
pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
flush()
# text_encoder = pipe.text_encoder
# text_encoder.to(self.device_torch, dtype=dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
tokenizer = pipe.tokenizer
else:
if self.custom_pipeline is not None:
pipln = self.custom_pipeline
@@ -273,7 +305,11 @@ class StableDiffusion:
# add hacks to unet to help training
# pipe.unet = prepare_unet_for_training(pipe.unet)
self.unet: 'UNet2DConditionModel' = pipe.unet
if self.is_pixart:
# pixart doesnt use a unet
self.unet = pipe.transformer
else:
self.unet: 'UNet2DConditionModel' = pipe.unet
self.vae: 'AutoencoderKL' = pipe.vae.to(self.device_torch, dtype=dtype)
self.vae.eval()
self.vae.requires_grad_(False)
@@ -381,7 +417,8 @@ class StableDiffusion:
sampler,
{
"prediction_type": self.prediction_type,
}
},
'sd' if not self.is_pixart else 'pixart'
)
try:
@@ -425,6 +462,16 @@ class StableDiffusion:
**extra_args
).to(self.device_torch)
pipeline.watermark = None
elif self.is_pixart:
pipeline = PixArtAlphaPipeline(
vae=self.vae,
transformer=self.unet,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
scheduler=noise_scheduler,
**extra_args
).to(self.device_torch)
else:
pipeline = Pipe(
vae=self.vae,
@@ -615,6 +662,23 @@ class StableDiffusion:
latents=gen_config.latents,
**extra
).images[0]
elif self.is_pixart:
# needs attention masks for some reason
img = pipeline(
prompt=None,
prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype),
negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype),
negative_prompt=None,
# negative_prompt=gen_config.negative_prompt,
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
latents=gen_config.latents,
**extra
).images[0]
else:
img = pipeline(
# prompt=gen_config.prompt,
@@ -1005,12 +1069,53 @@ class StableDiffusion:
f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}")
# predict the noise residual
noise_pred = self.unet(
latent_model_input.to(self.device_torch, self.torch_dtype),
timestep,
encoder_hidden_states=text_embeddings.text_embeds,
**kwargs,
).sample
if self.is_pixart:
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
batch_size, ch, h, w = list(latents.shape)
height = h * VAE_SCALE_FACTOR
width = w * VAE_SCALE_FACTOR
aspect_ratio_bin = (
ASPECT_RATIO_1024_BIN if self.unet.config.sample_size == 128 else ASPECT_RATIO_512_BIN
)
orig_height, orig_width = height, width
height, width = self.pipeline.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
if self.unet.config.sample_size == 128:
resolution = torch.tensor([height, width]).repeat(batch_size, 1)
aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size, 1)
resolution = resolution.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch)
aspect_ratio = aspect_ratio.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch)
if do_classifier_free_guidance:
resolution = torch.cat([resolution, resolution], dim=0)
aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
noise_pred = self.unet(
latent_model_input.to(self.device_torch, self.torch_dtype),
encoder_hidden_states=text_embeddings.text_embeds,
encoder_attention_mask=text_embeddings.attention_mask,
timestep=timestep,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
**kwargs
)[0]
# learned sigma
if self.unet.config.out_channels // 2 == self.unet.config.in_channels:
noise_pred = noise_pred.chunk(2, dim=1)[0]
else:
noise_pred = noise_pred
else:
noise_pred = self.unet(
latent_model_input.to(self.device_torch, self.torch_dtype),
timestep,
encoder_hidden_states=text_embeddings.text_embeds,
**kwargs,
).sample
if do_classifier_free_guidance:
# perform guidance
@@ -1142,6 +1247,20 @@ class StableDiffusion:
dropout_prob=dropout_prob,
)
)
elif self.is_pixart:
embeds, attention_mask = train_tools.encode_prompts_pixart(
self.tokenizer,
self.text_encoder,
prompt,
truncate=not long_prompts,
max_length=max_length,
dropout_prob=dropout_prob
)
return PromptEmbeds(
embeds,
attention_mask=attention_mask,
)
else:
return PromptEmbeds(
train_tools.encode_prompts(
@@ -1489,6 +1608,11 @@ class StableDiffusion:
def save_device_state(self):
# saves the current device state for all modules
# this is useful for when we want to alter the state and restore it
if self.is_pixart:
unet_has_grad = self.unet.proj_out.weight.requires_grad
else:
unet_has_grad = self.unet.conv_in.weight.requires_grad
self.device_state = {
**empty_preset,
'vae': {
@@ -1498,7 +1622,7 @@ class StableDiffusion:
'unet': {
'training': self.unet.training,
'device': self.unet.device,
'requires_grad': self.unet.conv_in.weight.requires_grad,
'requires_grad': unet_has_grad,
},
}
if isinstance(self.text_encoder, list):
@@ -1511,10 +1635,15 @@ class StableDiffusion:
'requires_grad': encoder.text_model.final_layer_norm.weight.requires_grad
})
else:
if isinstance(self.text_encoder, T5EncoderModel):
te_has_grad = self.text_encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
else:
te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad
self.device_state['text_encoder'] = {
'training': self.text_encoder.training,
'device': self.text_encoder.device,
'requires_grad': self.text_encoder.text_model.final_layer_norm.weight.requires_grad
'requires_grad': te_has_grad
}
if self.adapter is not None:
if isinstance(self.adapter, IPAdapter):